def main():
    """Parse CLI arguments and run distributed SuperBERT pre-training.

    The function:

    1. builds the argument parser and reads the distributed settings
       (``RANK``, ``WORLD_SIZE``, ``MASTER_ADDR``, ``MASTER_PORT``) from the
       environment,
    2. initialises the NCCL process group and the local output directories,
    3. loads the student model (from a checkpoint when ``--further_train`` is
       set, from scratch otherwise) and, unless ``--mlm_loss`` is set, a
       teacher ``BertModel`` for knowledge distillation,
    4. for every epoch loads that epoch's dataset, samples a sub-architecture
       per batch (unless ``--further_train``), accumulates the loss and steps
       the optimizer,
    5. writes tensorboard scalars on local rank 0 and saves a checkpoint at
       the end of each epoch on global rank 0.

    Raises:
        AssertionError: if CUDA is not available.
        ValueError: if ``--gradient_accumulation_steps`` is smaller than 1.
        ImportError: if ``--fp16`` is requested but apex is not installed.
    """
    parser = ArgumentParser()
    # NOTE: `required=True` means argparse never falls back to `default` for
    # --pregenerated_data and --student_model.
    parser.add_argument(
        '--pregenerated_data',
        type=str,
        required=True,
        default='/nas/hebin/data/english-exp/books_wiki_tokens_ngrams')
    parser.add_argument('--s3_output_dir', type=str, default='huawei_yun')
    parser.add_argument('--student_model',
                        type=str,
                        default='8layer_bert',
                        required=True)
    parser.add_argument('--teacher_model', type=str, default='electra_base')
    parser.add_argument('--cache_dir', type=str, default='/cache', help='')

    parser.add_argument("--epochs",
                        type=int,
                        default=2,
                        help="Number of epochs to train for")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=1e-4,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--max_seq_length", type=int, default=512)

    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--scratch',
                        action='store_true',
                        help="Whether to train from scratch")
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Store training data as on-disc memmaps to massively reduce memory usage"
    )
    parser.add_argument('--debug',
                        action='store_true',
                        help="Whether to debug")

    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )

    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument("--already_trained_epoch", default=0, type=int)
    parser.add_argument(
        "--masked_lm_prob",
        type=float,
        default=0.0,
        help="Probability of masking each token for the LM task")
    parser.add_argument(
        "--max_predictions_per_seq",
        type=int,
        default=77,
        help="Maximum number of tokens to mask in each sequence")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--logging_steps",
                        type=int,
                        default=500,
                        help="Log every X updates steps.")
    parser.add_argument("--warmup_steps",
                        default=10000,
                        type=int,
                        help="Linear warmup over warmup_steps.")

    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")

    parser.add_argument("--num_workers",
                        type=int,
                        default=4,
                        help="num_workers.")
    parser.add_argument("--continue_index", type=int, default=0, help="")
    parser.add_argument("--threads",
                        type=int,
                        default=27,
                        help="Number of threads to preprocess input data")

    # Search space for the sub-model architecture: each option takes a
    # [min, max] pair.
    parser.add_argument('--layer_num_space',
                        nargs='+',
                        type=int,
                        default=[1, 8])
    parser.add_argument('--hidden_size_space',
                        nargs='+',
                        type=int,
                        default=[128, 768])
    parser.add_argument('--qkv_size_space',
                        nargs='+',
                        type=int,
                        default=[180, 768])
    parser.add_argument('--intermediate_size_space',
                        nargs='+',
                        type=int,
                        default=[128, 3072])
    parser.add_argument('--head_num_space',
                        nargs='+',
                        type=int,
                        default=[1, 12])
    parser.add_argument('--sample_times_per_batch', type=int, default=1)
    parser.add_argument('--further_train', action='store_true')
    parser.add_argument('--mlm_loss', action='store_true')

    # Arguments for Huawei cloud
    parser.add_argument('--data_url', type=str, default='', help='s3 url')
    parser.add_argument("--train_url", type=str, default="", help="s3 url")

    args = parser.parse_args()

    assert (torch.cuda.is_available())
    device_count = torch.cuda.device_count()
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))

    # Rendezvous address for the process group.
    # NOTE(review): no URL scheme such as 'tcp://' is prepended here;
    # presumably MASTER_ADDR already carries one - confirm.
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method = master_ip + ':' + master_port

    # The device id comes straight from --local_rank (not derived from RANK).
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    print('device_id: %s' % args.local_rank)
    print('device_count: %s, rank: %s, world_size: %s' %
          (device_count, args.rank, args.world_size))
    print(init_method)

    torch.distributed.init_process_group(backend='nccl',
                                         world_size=args.world_size,
                                         rank=args.rank,
                                         init_method=init_method)

    LOCAL_DIR = args.cache_dir
    if oncloud:
        assert mox.file.exists(LOCAL_DIR)

    if args.local_rank == 0 and oncloud:
        logging.info(
            mox.file.list_directory(args.pregenerated_data, recursive=True))
        logging.info(
            mox.file.list_directory(args.student_model, recursive=True))

    local_save_dir = os.path.join(LOCAL_DIR, 'output', 'superbert',
                                  'checkpoints')
    local_tsbd_dir = os.path.join(LOCAL_DIR, 'output', 'superbert',
                                  'tensorboard')
    # Run name that encodes the main hyper-parameters.
    save_name = '_'.join([
        'superbert',
        'epoch',
        str(args.epochs),
        'lr',
        str(args.learning_rate),
        'bsz',
        str(args.train_batch_size),
        'grad_accu',
        str(args.gradient_accumulation_steps),
        str(args.max_seq_length),
        'gpu',
        str(args.world_size),
    ])
    bash_save_dir = os.path.join(local_save_dir, save_name)
    bash_tsbd_dir = os.path.join(local_tsbd_dir, save_name)
    if args.local_rank == 0:
        if not os.path.exists(bash_save_dir):
            os.makedirs(bash_save_dir)
            logger.info(bash_save_dir + ' created!')
        if not os.path.exists(bash_tsbd_dir):
            os.makedirs(bash_tsbd_dir)
            logger.info(bash_tsbd_dir + ' created!')

    local_data_dir_tmp = '/cache/data/tmp/'
    local_data_dir = local_data_dir_tmp + save_name

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # From here on train_batch_size is the per-forward (micro) batch size.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    args.tokenizer = BertTokenizer.from_pretrained(
        args.student_model, do_lower_case=args.do_lower_case)
    args.vocab_list = list(args.tokenizer.vocab.keys())

    config = BertConfig.from_pretrained(
        os.path.join(args.student_model, CONFIG_NAME))
    logger.info("Model config {}".format(config))

    # --mlm_loss selects the MLM student class, otherwise the distillation
    # (TinyBERT-style) student class is used.
    if args.further_train:
        if args.mlm_loss:
            student_model = SuperBertForPreTraining.from_pretrained(
                args.student_model, config)
        else:
            student_model = SuperTinyBertForPreTraining.from_pretrained(
                args.student_model, config)
    else:
        if args.mlm_loss:
            student_model = SuperBertForPreTraining.from_scratch(
                args.student_model, config)
        else:
            student_model = SuperTinyBertForPreTraining.from_scratch(
                args.student_model, config)

    student_model.to(device)

    if not args.mlm_loss:
        teacher_model = BertModel.from_pretrained(args.teacher_model)
        teacher_model.to(device)

    # Build the architecture search space from the [min, max] pairs.
    min_hidden_size, max_hidden_size = args.hidden_size_space
    min_ffn_size, max_ffn_size = args.intermediate_size_space
    min_qkv_size, max_qkv_size = args.qkv_size_space
    min_head_num, max_head_num = args.head_num_space

    hidden_step = 4
    ffn_step = 4
    qkv_step = 12
    head_step = 1

    number_hidden_step = int((max_hidden_size - min_hidden_size) / hidden_step)
    number_ffn_step = int((max_ffn_size - min_ffn_size) / ffn_step)
    number_qkv_step = int((max_qkv_size - min_qkv_size) / qkv_step)
    number_head_step = int((max_head_num - min_head_num) / head_step)

    layer_numbers = list(
        range(args.layer_num_space[0], args.layer_num_space[1] + 1))
    hidden_sizes = [
        i * hidden_step + min_hidden_size
        for i in range(number_hidden_step + 1)
    ]
    ffn_sizes = [
        i * ffn_step + min_ffn_size for i in range(number_ffn_step + 1)
    ]
    qkv_sizes = [
        i * qkv_step + min_qkv_size for i in range(number_qkv_step + 1)
    ]
    head_numbers = [
        i * head_step + min_head_num for i in range(number_head_step + 1)
    ]

    # Only local rank 0 owns a tensorboard writer.
    if args.local_rank == 0:
        tb_writer = SummaryWriter(bash_tsbd_dir)

    global_step = 0
    step = 0
    tr_loss, tr_rep_loss, tr_att_loss = 0.0, 0.0, 0.0
    logging_loss, rep_logging_loss, att_logging_loss = 0.0, 0.0, 0.0
    end_time, start_time = 0, 0

    submodel_config = dict()

    # With --further_train the sub-model is fixed to the full configuration
    # instead of being re-sampled for every batch.
    if args.further_train:
        submodel_config['sample_layer_num'] = config.num_hidden_layers
        submodel_config['sample_hidden_size'] = config.hidden_size
        submodel_config[
            'sample_intermediate_sizes'] = config.num_hidden_layers * [
                config.intermediate_size
            ]
        submodel_config[
            'sample_num_attention_heads'] = config.num_hidden_layers * [
                config.num_attention_heads
            ]
        submodel_config['sample_qkv_sizes'] = config.num_hidden_layers * [
            config.qkv_size
        ]

    # The loss module is stateless, so one instance serves every epoch.
    from torch.nn import MSELoss
    loss_mse = MSELoss()

    for epoch in range(args.epochs):
        if epoch < args.continue_index:
            # Resuming: skip finished epochs and disable warmup.
            args.warmup_steps = 0
            continue

        args.local_data_dir = os.path.join(local_data_dir, str(epoch))
        if args.local_rank == 0:
            # exist_ok: a re-run with the same run name must not crash on a
            # directory left over from a previous attempt.
            os.makedirs(args.local_data_dir, exist_ok=True)
        # Other local ranks wait until rank 0 has created the directory;
        # sleep between polls instead of spinning a CPU core.
        while not os.path.exists(args.local_data_dir):
            time.sleep(1)
        epoch_dataset = load_doc_tokens_ngrams(args)

        if args.local_rank == 0 and oncloud:
            logging.info('Dataset in epoch %s', epoch)
            logging.info(
                mox.file.list_directory(args.local_data_dir, recursive=True))

        # num_replicas=1 / rank=0: every process iterates its whole
        # epoch_dataset. Presumably the data is already sharded per process
        # by load_doc_tokens_ngrams - TODO confirm.
        train_sampler = DistributedSampler(epoch_dataset,
                                           num_replicas=1,
                                           rank=0)

        train_dataloader = DataLoader(epoch_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        step_in_each_epoch = len(
            train_dataloader) // args.gradient_accumulation_steps
        num_train_optimization_steps = step_in_each_epoch * args.epochs
        logging.info("***** Running training *****")
        logging.info("  Num examples = %d",
                     len(epoch_dataset) * args.world_size)
        logger.info("  Num Epochs = %d", args.epochs)
        logging.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            args.train_batch_size * args.gradient_accumulation_steps *
            args.world_size)
        logger.info("  Gradient Accumulation steps = %d",
                    args.gradient_accumulation_steps)
        logging.info("  Num steps = %d", num_train_optimization_steps)

        # Optimizer, AMP and DDP wrappers are created once, in the first
        # epoch that is actually trained.
        if epoch == args.continue_index:
            # Prepare optimizer: no weight decay for biases and LayerNorm.
            param_optimizer = list(student_model.named_parameters())
            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.01
            }, {
                'params': [
                    p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0
            }]

            warm_up_ratio = args.warmup_steps / num_train_optimization_steps
            print('warm_up_ratio: {}'.format(warm_up_ratio))
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 e=args.adam_epsilon,
                                 schedule='warmup_linear',
                                 t_total=num_train_optimization_steps,
                                 warmup=warm_up_ratio)

            if args.fp16:
                try:
                    from apex import amp
                except ImportError:
                    raise ImportError(
                        "Please install apex from https://www.github.com/nvidia/apex"
                        " to use fp16 training.")
                student_model, optimizer = amp.initialize(
                    student_model,
                    optimizer,
                    opt_level=args.fp16_opt_level,
                    min_loss_scale=1)

            # apex distributed data parallel
            student_model = DDP(
                student_model,
                message_size=10000000,
                gradient_predivide_factor=torch.distributed.get_world_size(),
                delay_allreduce=True)

            if not args.mlm_loss:
                teacher_model = DDP(teacher_model,
                                    message_size=10000000,
                                    gradient_predivide_factor=torch.
                                    distributed.get_world_size(),
                                    delay_allreduce=True)
                teacher_model.eval()

            logger.info('apex data paralleled!')

        student_model.train()
        for batch in train_dataloader:
            step += 1
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_masks, lm_label_ids = batch

            if not args.mlm_loss:
                teacher_last_rep, teacher_last_att = teacher_model(
                    input_ids, input_masks)
                # Zero out attention entries at or below -1e2.
                teacher_last_att = torch.where(
                    teacher_last_att <= -1e2,
                    torch.zeros_like(teacher_last_att).to(device),
                    teacher_last_att)
                # Tensor.detach() returns a new tensor rather than working
                # in place, so the result has to be assigned; otherwise the
                # teacher graph stays attached and is back-propagated
                # through together with the student.
                teacher_last_rep = teacher_last_rep.detach()
                teacher_last_att = teacher_last_att.detach()

            for sample_idx in range(args.sample_times_per_batch):
                att_loss = 0.
                rep_loss = 0.
                # Seed passed to the architecture sampler.
                rand_seed = int(global_step * args.world_size + sample_idx)

                if not args.mlm_loss:
                    if not args.further_train:
                        submodel_config = sample_arch_4_kd(
                            layer_numbers,
                            hidden_sizes,
                            ffn_sizes,
                            qkv_sizes,
                            reset_rand_seed=True,
                            rand_seed=rand_seed)
                    # knowledge distillation
                    student_last_rep, student_last_att = student_model(
                        input_ids, submodel_config, attention_mask=input_masks)
                    student_last_att = torch.where(
                        student_last_att <= -1e2,
                        torch.zeros_like(student_last_att).to(device),
                        student_last_att)

                    att_loss += loss_mse(student_last_att, teacher_last_att)
                    rep_loss += loss_mse(student_last_rep, teacher_last_rep)
                    loss = att_loss + rep_loss

                    if args.gradient_accumulation_steps > 1:
                        rep_loss = rep_loss / args.gradient_accumulation_steps
                        att_loss = att_loss / args.gradient_accumulation_steps
                        loss = loss / args.gradient_accumulation_steps

                    tr_rep_loss += rep_loss.item()
                    tr_att_loss += att_loss.item()
                else:
                    if not args.further_train:
                        submodel_config = sample_arch_4_mlm(
                            layer_numbers,
                            hidden_sizes,
                            ffn_sizes,
                            head_numbers,
                            reset_rand_seed=True,
                            rand_seed=rand_seed)
                    loss = student_model(input_ids,
                                         submodel_config,
                                         attention_mask=input_masks,
                                         masked_lm_labels=lm_label_ids)

                tr_loss += loss.item()
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward(retain_graph=True)
                else:
                    loss.backward(retain_graph=True)

            # NOTE(review): `step` was already incremented at the top of the
            # loop, so with gradient_accumulation_steps > 1 the first update
            # fires one micro-batch early. Left unchanged because fixing it
            # alters the update schedule - confirm before changing.
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(student_model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                # Periodic console log on local ranks 0 and 1; during the
                # first 100 updates every rank logs every update. The
                # parentheses spell out the existing `and`/`or` precedence.
                periodic_log = (
                    (step + 1) %
                    (args.gradient_accumulation_steps * args.logging_steps)
                    == 0 and args.local_rank < 2)
                if periodic_log or global_step < 100:
                    end_time = time.time()

                    if not args.mlm_loss:
                        logger.info(
                            'Epoch: %s, global_step: %s/%s, lr: %s, loss is %s; '
                            'rep_loss is %s; att_loss is %s; (%.2f sec)' %
                            (epoch, global_step + 1, step_in_each_epoch,
                             optimizer.get_lr()[0],
                             loss.item() * args.gradient_accumulation_steps,
                             rep_loss.item() *
                             args.gradient_accumulation_steps, att_loss.item()
                             * args.gradient_accumulation_steps,
                             end_time - start_time))
                    else:
                        logger.info(
                            'Epoch: %s, global_step: %s/%s, lr: %s, loss is %s; '
                            ' (%.2f sec)' %
                            (epoch, global_step + 1, step_in_each_epoch,
                             optimizer.get_lr()[0],
                             loss.item() * args.gradient_accumulation_steps,
                             end_time - start_time))
                    start_time = time.time()

                # Tensorboard: average of the running losses since the
                # previous log point.
                if args.logging_steps > 0 and global_step % args.logging_steps == 0 and args.local_rank == 0:
                    tb_writer.add_scalar("lr",
                                         optimizer.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)

                    if not args.mlm_loss:
                        tb_writer.add_scalar("rep_loss",
                                             (tr_rep_loss - rep_logging_loss) /
                                             args.logging_steps, global_step)
                        tb_writer.add_scalar("att_loss",
                                             (tr_att_loss - att_logging_loss) /
                                             args.logging_steps, global_step)
                        rep_logging_loss = tr_rep_loss
                        att_logging_loss = tr_att_loss

                    logging_loss = tr_loss

        # Save a trained model (global rank 0 only).
        if args.rank == 0:
            saving_path = bash_save_dir
            saving_path = Path(os.path.join(saving_path,
                                            "epoch_" + str(epoch)))

            if saving_path.is_dir() and list(saving_path.iterdir()):
                logging.warning(
                    f"Output directory ({ saving_path }) already exists and is not empty!"
                )
            saving_path.mkdir(parents=True, exist_ok=True)

            logging.info("** ** * Saving fine-tuned model ** ** * ")
            # Unwrap the DDP container so only the model itself is saved.
            model_to_save = student_model.module if hasattr(student_model, 'module')\
                else student_model

            output_model_file = os.path.join(saving_path, WEIGHTS_NAME)
            output_config_file = os.path.join(saving_path, CONFIG_NAME)

            torch.save(model_to_save.state_dict(), output_model_file)
            model_to_save.config.to_json_file(output_config_file)
            args.tokenizer.save_vocabulary(saving_path)

            torch.save(optimizer.state_dict(),
                       os.path.join(saving_path, "optimizer.pt"))
            logger.info("Saving optimizer and scheduler states to %s",
                        saving_path)

            # On the cloud, mirror the whole local output tree to S3.
            if oncloud:
                local_output_dir = os.path.join(LOCAL_DIR, 'output')
                logger.info(
                    mox.file.list_directory(local_output_dir, recursive=True))
                logger.info('s3_output_dir: ' + args.s3_output_dir)
                mox.file.copy_parallel(local_output_dir, args.s3_output_dir)

    if args.local_rank == 0:
        tb_writer.close()
# Example #2
# 0
class Runner():
    ''' Handler for complete pre-training progress of upstream models '''
    def __init__(self, args, config, dataloader, ckpdir):
        """Store the run settings and read hyper-parameters from ``config``.

        Args:
            args: parsed options; ``args.gpu`` is read here.
            config: nested mapping with ``'optimizer'``, ``'runner'`` and
                ``'transformer'`` sections.
            dataloader: iterable that yields the training batches.
            ckpdir: directory handed to ``SummaryWriter`` and kept for
                checkpoints.
        """
        # Use CUDA only when it is both requested and available.
        if args.gpu and torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        if torch.cuda.is_available():
            print('[Runner] - CUDA is available!')

        self.model_kept = []  # paths of the checkpoints written so far
        self.global_step = 1
        self.log = SummaryWriter(ckpdir)

        self.args = args
        self.config = config
        self.dataloader = dataloader
        self.ckpdir = ckpdir

        # Optimizer settings
        optimizer_cfg = config['optimizer']
        self.learning_rate = float(optimizer_cfg['learning_rate'])
        self.warmup_proportion = optimizer_cfg['warmup_proportion']
        self.gradient_accumulation_steps = optimizer_cfg[
            'gradient_accumulation_steps']
        self.gradient_clipping = optimizer_cfg['gradient_clipping']

        # Training-loop settings
        runner_cfg = config['runner']
        self.apex = runner_cfg['apex']
        self.total_steps = runner_cfg['total_steps']
        self.log_step = runner_cfg['log_step']
        self.save_step = runner_cfg['save_step']
        self.duo_feature = runner_cfg['duo_feature']
        self.max_keep = runner_cfg['max_keep']

        # Model settings
        self.transformer_config = config['transformer']
        self.input_dim = self.transformer_config['input_dim']
        # With duo features the output dim is 1025; otherwise None, meaning
        # the output dim is the same as the input dim.
        if self.duo_feature:
            self.output_dim = 1025
        else:
            self.output_dim = None

    def set_model(self):
        """Build the masked acoustic model and its optimizer.

        Sets ``self.dr``, ``self.hidden_size``, ``self.model`` and
        ``self.optimizer``. When ``self.apex`` is true it also sets
        ``self.warmup_linear``.

        Raises:
            ImportError: if ``self.apex`` is true but apex cannot be imported.
        """
        print('[Runner] - Initializing Transformer model...')

        # Transformer model with speech prediction head
        model_config = TransformerConfig(self.config)
        self.dr = model_config.downsample_rate
        self.hidden_size = model_config.hidden_size

        network = TransformerForMaskedAcousticModel(model_config,
                                                    self.input_dim,
                                                    self.output_dim)
        self.model = network.to(self.device)
        self.model.train()

        if self.args.multi_gpu:
            self.model = torch.nn.DataParallel(self.model)
            print('[Runner] - Multi-GPU training Enabled: ' +
                  str(torch.cuda.device_count()))

        trainable_count = sum(p.numel() for p in self.model.parameters()
                              if p.requires_grad)
        print('[Runner] - Number of parameters: ' + str(trainable_count))

        # Split the parameters into a weight-decayed group and a group that
        # is exempt from decay (biases and LayerNorm parameters).
        exempt_tags = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        decayed, exempt = [], []
        for param_name, param in list(self.model.named_parameters()):
            if any(tag in param_name for tag in exempt_tags):
                exempt.append(param)
            else:
                decayed.append(param)
        optimizer_grouped_parameters = [
            {'params': decayed, 'weight_decay': 0.01},
            {'params': exempt, 'weight_decay': 0.0},
        ]

        if not self.apex:
            # BertAdam takes the warmup settings directly.
            self.optimizer = BertAdam(optimizer_grouped_parameters,
                                      lr=self.learning_rate,
                                      warmup=self.warmup_proportion,
                                      t_total=self.total_steps)
            return

        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        fused_adam = FusedAdam(optimizer_grouped_parameters,
                               lr=self.learning_rate,
                               bias_correction=False,
                               max_grad_norm=1.0)
        # A loss scale of 0 selects dynamic scaling; any other value is
        # used as a static scale.
        loss_scale = self.config['optimizer']['loss_scale']
        if loss_scale == 0:
            self.optimizer = FP16_Optimizer(fused_adam,
                                            dynamic_loss_scale=True)
        else:
            self.optimizer = FP16_Optimizer(fused_adam,
                                            static_loss_scale=loss_scale)
        self.warmup_linear = WarmupLinearSchedule(
            warmup=self.warmup_proportion, t_total=self.total_steps)

    def save_model(self, name='states', to_path=None):
        all_states = {
            'SpecHead':
            self.model.SpecHead.state_dict() if not self.args.multi_gpu else
            self.model.module.SpecHead.state_dict(),
            'Transformer':
            self.model.Transformer.state_dict() if not self.args.multi_gpu else
            self.model.module.Transformer.state_dict(),
            'Optimizer':
            self.optimizer.state_dict(),
            'Global_step':
            self.global_step,
            'Settings': {
                'Config': self.config,
                'Paras': self.args,
            },
        }

        if to_path is None:
            new_model_path = '{}/{}-{}.ckpt'.format(self.ckpdir, name,
                                                    self.global_step)
        else:
            new_model_path = to_path

        torch.save(all_states, new_model_path)
        self.model_kept.append(new_model_path)

        if len(self.model_kept) >= self.max_keep:
            os.remove(self.model_kept[0])
            self.model_kept.pop(0)

    def up_sample_frames(self, spec, return_first=False):
        if len(spec.shape) != 3:
            spec = spec.unsqueeze(0)
            assert (len(spec.shape) == 3
                    ), 'Input should have acoustic feature of shape BxTxD'
        # spec shape: [batch_size, sequence_length // downsample_rate, output_dim * downsample_rate]
        spec_flatten = spec.view(spec.shape[0], spec.shape[1] * self.dr,
                                 spec.shape[2] // self.dr)
        if return_first: return spec_flatten[0]
        return spec_flatten  # spec_flatten shape: [batch_size, sequence_length * downsample_rate, output_dim // downsample_rate]

    def down_sample_frames(self, spec):
        left_over = spec.shape[1] % self.dr
        if left_over != 0: spec = spec[:, :-left_over, :]
        spec_stacked = spec.view(spec.shape[0], spec.shape[1] // self.dr,
                                 spec.shape[2] * self.dr)
        return spec_stacked

    def process_data(self, spec):
        """Unpack one dataloader batch and move its tensors to `self.device`.

        `spec` must hold exactly five items, in the order
        (spec_masked, pos_enc, mask_label, attn_mask, spec_stacked).
        Everything runs under `torch.no_grad()`.

        Returns the same five items, in the same order, after device
        placement: (x, pos_enc, mask_label, attention_mask, y).
        """
        with torch.no_grad():

            assert (
                len(spec) == 5
            ), 'dataloader should return (spec_masked, pos_enc, mask_label, attn_mask, spec_stacked)'
            # Bucketing hack: each item is expected to arrive as 1xBxTxD, so
            # the leading singleton dimension is squeezed away first.
            x, pos, labels, attention, y = (spec[i].squeeze(0)
                                            for i in range(5))

            target = self.device
            x = x.to(device=target)

            pos_rank = pos.dim()
            if pos_rank == 3:
                # Already (batch_size, seq_len, hidden_size): the whole
                # batch-sized table is placed on the device.
                pos = torch.FloatTensor(pos).to(device=target)
            elif pos_rank == 2:
                # (seq_len, hidden_size): place a single table on the device
                # and expand it over the batch dimension.
                pos_on_device = torch.FloatTensor(pos).to(device=target)
                pos = pos_on_device.expand(x.size(0), *pos.size())

            labels = torch.ByteTensor(labels).to(device=target)
            attention = torch.FloatTensor(attention).to(device=target)
            y = y.to(device=target)

        return x, pos, labels, attention, y  # (x, pos_enc, mask_label, attention_mask, y)

    def train(self):
        ''' Self-Supervised Pre-Training of Transformer Model

        Iterates `self.dataloader` repeatedly until `self.global_step` exceeds
        `self.total_steps`.  For every valid batch the loss is back-propagated;
        once `self.gradient_accumulation_steps` micro-batches have been
        accumulated the gradients are clipped, the optimizer is stepped (unless
        the gradient norm is NaN), scalars/images are logged and a checkpoint
        is written at the configured intervals.

        A "CUDA out of memory" RuntimeError empties the CUDA cache, clears the
        gradients and continues with the next batch; any other RuntimeError is
        re-raised.
        '''

        pbar = tqdm(total=self.total_steps)
        while self.global_step <= self.total_steps:

            progress = tqdm(self.dataloader, desc="Iteration")

            step = 0  # valid micro-batches processed in this pass over the data
            loss_val = 0  # loss accumulated since the last optimizer update
            for batch_is_valid, *batch in progress:
                try:
                    if self.global_step > self.total_steps: break
                    if not batch_is_valid: continue
                    step += 1

                    spec_masked, pos_enc, mask_label, attn_mask, spec_stacked = self.process_data(
                        batch)
                    loss, pred_spec = self.model(spec_masked, pos_enc,
                                                 mask_label, attn_mask,
                                                 spec_stacked)

                    # Accumulate Loss
                    if self.gradient_accumulation_steps > 1:
                        loss = loss / self.gradient_accumulation_steps
                    if self.apex and self.args.multi_gpu:
                        raise NotImplementedError
                    elif self.apex:
                        self.optimizer.backward(loss)
                    elif self.args.multi_gpu:
                        loss = loss.sum()
                        loss.backward()
                    else:
                        loss.backward()
                    loss_val += loss.item()

                    # Update
                    # `step` has already been incremented for this batch, so it
                    # is a 1-based count; testing `step + 1` made the first
                    # update of every pass fire one micro-batch too early.
                    if step % self.gradient_accumulation_steps == 0:
                        if self.apex:
                            # modify learning rate with special warm up BERT uses
                            # if config.apex is False, BertAdam is used and handles this automatically
                            lr_this_step = self.learning_rate * self.warmup_linear.get_lr(
                                self.global_step, self.warmup_proportion)
                            for param_group in self.optimizer.param_groups:
                                param_group['lr'] = lr_this_step

                        # Step
                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.gradient_clipping)
                        if math.isnan(grad_norm):
                            # Skip the parameter update but still clear the gradients below.
                            print(
                                '[Runner] - Error : grad norm is NaN @ step ' +
                                str(self.global_step))
                        else:
                            self.optimizer.step()
                        self.optimizer.zero_grad()

                        if self.global_step % self.log_step == 0:
                            # Log
                            self.log.add_scalar('lr',
                                                self.optimizer.get_lr()[0],
                                                self.global_step)
                            self.log.add_scalar('loss', (loss_val),
                                                self.global_step)
                            self.log.add_scalar('gradient norm', grad_norm,
                                                self.global_step)
                            progress.set_description("Loss %.4f" % (loss_val))

                        if self.global_step % self.save_step == 0:
                            # Checkpoint, then log the first example of the
                            # batch (masked input, prediction, target) as images.
                            self.save_model('states')
                            mask_spec = self.up_sample_frames(
                                spec_masked[0], return_first=True)
                            pred_spec = self.up_sample_frames(
                                pred_spec[0], return_first=True)
                            true_spec = self.up_sample_frames(
                                spec_stacked[0], return_first=True)
                            mask_spec = plot_spectrogram_to_numpy(
                                mask_spec.data.cpu().numpy())
                            pred_spec = plot_spectrogram_to_numpy(
                                pred_spec.data.cpu().numpy())
                            true_spec = plot_spectrogram_to_numpy(
                                true_spec.data.cpu().numpy())
                            self.log.add_image('mask_spec', mask_spec,
                                               self.global_step)
                            self.log.add_image('pred_spec', pred_spec,
                                               self.global_step)
                            self.log.add_image('true_spec', true_spec,
                                               self.global_step)

                        loss_val = 0
                        pbar.update(1)
                        self.global_step += 1

                except RuntimeError as e:
                    if 'CUDA out of memory' in str(e):
                        # Best effort recovery: drop this batch and its gradients.
                        print('CUDA out of memory at step: ', self.global_step)
                        torch.cuda.empty_cache()
                        self.optimizer.zero_grad()
                    else:
                        raise

        pbar.close()
        self.log.close()
# Example #3
# 0
# File: solver.py  Project: 592595/TERA
class Solver():
    ''' Super class Solver for all kinds of tasks'''
    def __init__(self, config, paras):

        # General Settings
        self.config = config
        self.paras = paras
        self.transformer_config = config['transformer']
        self.device = torch.device('cuda') if (
            self.paras.gpu
            and torch.cuda.is_available()) else torch.device('cpu')
        if torch.cuda.is_available(): self.verbose('CUDA is available!')

        # path and directories
        self.exp_name = paras.name
        if self.exp_name is None:
            self.exp_name = '_'.join([
                paras.config.split('/')[-1].replace('.yaml', ''),
                'sd' + str(paras.seed)
            ])
        self.ckpdir = paras.ckpdir
        self.expdir = os.path.join(self.ckpdir, self.exp_name)

        self.load = paras.load
        # only for test
        self.ckpt = os.path.join(self.ckpdir, paras.ckpt)

        # model
        self.load_model_list = config['solver']['load_model_list']
        self.duo_feature = config['solver']['duo_feature']
        self.output_dim = 1025 if self.duo_feature else None  # output dim is the same as input dim if not using duo features
        if 'input_dim' in self.transformer_config:
            self.input_dim = self.transformer_config['input_dim']
        else:
            raise ValueError(
                'Please update your config file to include the attribute `input_dim`.'
            )

    def verbose(self, msg, end='\n'):
        ''' Verbose function for print information to stdout'''
        if self.paras.verbose:
            print('[SOLVER] - ', msg, end=end)

    def load_data(self, split='train'):
        ''' Load data for training / testing'''
        if split == 'train':
            self.verbose('Loading source data ' +
                         str(self.config['dataloader']['train_set']) +
                         ' from ' + self.config['dataloader']['data_path'])
            if self.duo_feature:
                self.verbose('Loading target data ' +
                             str(self.config['dataloader']['train_set']) +
                             ' from ' +
                             self.config['dataloader']['target_path'])
        elif split == 'test':
            self.verbose('Loading testing data ' +
                         str(self.config['dataloader']['test_set']) +
                         ' from ' + self.config['dataloader']['data_path'])
        else:
            raise NotImplementedError('Invalid `split` argument!')

        if self.duo_feature:
            setattr(self, 'dataloader', get_Dataloader(split, load='duo', use_gpu=self.paras.gpu, \
                    mam_config=self.transformer_config, **self.config['dataloader'])) # run_mam is automatically performed
        else:
            setattr(self, 'dataloader', get_Dataloader(split, load='acoustic', use_gpu=self.paras.gpu, run_mam=True, \
                    mam_config=self.transformer_config, **self.config['dataloader']))

    def set_model(self,
                  inference=False,
                  with_head=False,
                  from_path=None,
                  output_attention=False):
        ''' Build the Transformer (optionally with its prediction head) and, for training, the optimizer

        inference=False               -> full model in train mode plus optimizer
        inference=True, with_head     -> full model in eval mode
        inference=True, no head       -> bare `TransformerModel` in eval mode
        When `self.load` is set, `load_model` is called afterwards with the
        same `inference`, `with_head` and `from_path` arguments.
        '''
        self.verbose('Initializing Transformer model.')

        # Build the Transformer model with speech prediction head
        model_config = TransformerConfig(self.config)
        self.model_config = model_config
        self.dr = model_config.downsample_rate
        self.hidden_size = model_config.hidden_size
        self.with_head = with_head
        self.output_attention = output_attention

        build_full_model = with_head or not inference
        if build_full_model:
            self.model = TransformerForMaskedAcousticModel(
                model_config, self.input_dim, self.output_dim,
                output_attention).to(self.device)
            self.transformer = self.model.Transformer
            if self.paras.multi_gpu:
                self.model = torch.nn.DataParallel(self.model)
                self.transformer = torch.nn.DataParallel(self.transformer)
                self.verbose('Multi-GPU training Enabled: ' +
                             str(torch.cuda.device_count()))
            trainable = sum(p.numel() for p in self.model.parameters()
                            if p.requires_grad)
            self.verbose('Number of parameters: ' + str(trainable))

        if inference and not with_head:
            self.transformer = TransformerModel(
                model_config, self.input_dim,
                output_attention).to(self.device)
            if self.paras.multi_gpu:
                self.transformer = torch.nn.DataParallel(self.transformer)
                self.verbose('Multi-GPU training Enabled: ' +
                             str(torch.cuda.device_count()))
            trainable = sum(p.numel() for p in self.transformer.parameters()
                            if p.requires_grad)
            self.verbose('Number of parameters: ' + str(trainable))
            self.transformer.eval()
        elif inference and with_head:
            self.model.eval()
        elif not inference:
            self.model.train()

            # Setup optimizer: parameters whose name contains one of the
            # `no_decay` markers get weight_decay 0.0, all others 0.01.
            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            decayed, undecayed = [], []
            for param_name, param in list(self.model.named_parameters()):
                if any(marker in param_name for marker in no_decay):
                    undecayed.append(param)
                else:
                    decayed.append(param)
            optimizer_grouped_parameters = [
                {'params': decayed, 'weight_decay': 0.01},
                {'params': undecayed, 'weight_decay': 0.0},
            ]

            if not self.apex:
                self.optimizer = BertAdam(optimizer_grouped_parameters,
                                          lr=self.learning_rate,
                                          warmup=self.warmup_proportion,
                                          t_total=self.total_steps)
            else:
                try:
                    from apex.optimizers import FP16_Optimizer
                    from apex.optimizers import FusedAdam
                except ImportError:
                    raise ImportError(
                        "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                    )

                fused = FusedAdam(optimizer_grouped_parameters,
                                  lr=self.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
                # A configured loss_scale of 0 selects dynamic loss scaling.
                loss_scale = self.config['optimizer']['loss_scale']
                if loss_scale == 0:
                    self.optimizer = FP16_Optimizer(fused,
                                                    dynamic_loss_scale=True)
                else:
                    self.optimizer = FP16_Optimizer(
                        fused, static_loss_scale=loss_scale)
                self.warmup_linear = WarmupLinearSchedule(
                    warmup=self.warmup_proportion, t_total=self.total_steps)
        else:
            raise NotImplementedError('Invalid Arguments!')

        if self.load:  # This will be set to True by default when Tester is running set_model()
            self.load_model(inference=inference,
                            with_head=with_head,
                            from_path=from_path)

    def save_model(self, name='states', model_all=True, to_path=None):
        if model_all:
            all_states = {
                'SpecHead':
                self.model.SpecHead.state_dict() if not self.paras.multi_gpu
                else self.model.module.SpecHead.state_dict(),
                'Transformer':
                self.transformer.state_dict() if not self.paras.multi_gpu else
                self.transformer.module.state_dict(),
                'Optimizer':
                self.optimizer.state_dict(),
                'Global_step':
                self.global_step,
                'Settings': {
                    'Config': self.config,
                    'Paras': self.paras,
                },
            }
        else:
            all_states = {
                'Transformer':
                self.transformer.state_dict() if not self.paras.multi_gpu else
                self.transformer.module.state_dict(),
                'Settings': {
                    'Config': self.config,
                    'Paras': self.paras,
                },
            }
        if to_path is None:
            new_model_path = '{}/{}-{}.ckpt'.format(self.expdir, name,
                                                    self.global_step)
        else:
            new_model_path = to_path
        torch.save(all_states, new_model_path)
        self.model_kept.append(new_model_path)

        if len(self.model_kept) >= self.max_keep:
            os.remove(self.model_kept[0])
            self.model_kept.pop(0)

    def load_model(self, inference=False, with_head=False, from_path=None):
        if from_path is not None:
            self.verbose('Load model from {}'.format(from_path))
            all_states = torch.load(from_path, map_location='cpu')
            self.load_model_list = ['Transformer']
        else:
            self.verbose('Load model from {}'.format(self.ckpt))
            all_states = torch.load(self.ckpt, map_location='cpu')

        if 'SpecHead' in self.load_model_list:
            if not inference or with_head:
                try:
                    if not self.paras.multi_gpu:
                        self.model.SpecHead.load_state_dict(
                            all_states['SpecHead'])
                    else:
                        self.model.module.SpecHead.load_state_dict(
                            all_states['SpecHead'])
                    self.verbose('[SpecHead] - Loaded')
                except:
                    self.verbose('[SpecHead - X]')

        if 'Transformer' in self.load_model_list:
            try:
                state_dict = all_states['Transformer']

                # Load from a PyTorch state_dict
                old_keys = []
                new_keys = []
                for key in state_dict.keys():
                    new_key = None
                    if 'gamma' in key:
                        new_key = key.replace('gamma', 'weight')
                    if 'beta' in key:
                        new_key = key.replace('beta', 'bias')
                    if new_key:
                        old_keys.append(key)
                        new_keys.append(new_key)
                for old_key, new_key in zip(old_keys, new_keys):
                    state_dict[new_key] = state_dict.pop(old_key)

                missing_keys = []
                unexpected_keys = []
                error_msgs = []
                # copy state_dict so _load_from_state_dict can modify it
                metadata = getattr(state_dict, '_metadata', None)
                state_dict = state_dict.copy()
                if metadata is not None:
                    state_dict._metadata = metadata

                def load(module, prefix=''):
                    local_metadata = {} if metadata is None else metadata.get(
                        prefix[:-1], {})
                    module._load_from_state_dict(state_dict, prefix,
                                                 local_metadata, True,
                                                 missing_keys, unexpected_keys,
                                                 error_msgs)
                    for name, child in module._modules.items():
                        if child is not None:
                            load(child, prefix + name + '.')

                # perform load
                if not self.paras.multi_gpu:
                    load(self.transformer)
                else:
                    load(self.transformer.module)

                if len(missing_keys) > 0:
                    self.verbose(
                        "Weights of {} not initialized from pretrained model: {}"
                        .format(self.transformer.__class__.__name__,
                                missing_keys))
                if len(unexpected_keys) > 0:
                    self.verbose(
                        "Weights from pretrained model not used in {}: {}".
                        format(self.transformer.__class__.__name__,
                               unexpected_keys))
                if len(error_msgs) > 0:
                    raise RuntimeError(
                        'Error(s) in loading state_dict for {}:\n\t{}'.format(
                            self.transformer.__class__.__name__,
                            "\n\t".join(error_msgs)))
                self.verbose('[Transformer] - Loaded')
            except:
                self.verbose('[Transformer - X]')

        if 'Optimizer' in self.load_model_list and not inference:
            try:
                self.optimizer.load_state_dict(all_states['Optimizer'])
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if torch.is_tensor(v):
                            state[k] = v.cuda()
                self.verbose('[Optimizer] - Loaded')
            except:
                self.verbose('[Optimizer - X]')

        if 'Global_step' in self.load_model_list and not inference:
            try:
                self.global_step = all_states['Global_step']
                self.verbose('[Global_step] - Loaded')
            except:
                self.verbose('[Global_step - X]')

        self.verbose('Model loading complete!')

    def up_sample_frames(self, spec, return_first=False):
        if len(spec.shape) != 3:
            spec = spec.unsqueeze(0)
            assert (len(spec.shape) == 3
                    ), 'Input should have acoustic feature of shape BxTxD'
        # spec shape: [batch_size, sequence_length // downsample_rate, output_dim * downsample_rate]
        spec_flatten = spec.view(spec.shape[0], spec.shape[1] * self.dr,
                                 spec.shape[2] // self.dr)
        if return_first: return spec_flatten[0]
        return spec_flatten  # spec_flatten shape: [batch_size, sequence_length * downsample_rate, output_dim // downsample_rate]

    def down_sample_frames(self, spec):
        left_over = spec.shape[1] % self.dr
        if left_over != 0: spec = spec[:, :-left_over, :]
        spec_stacked = spec.view(spec.shape[0], spec.shape[1] // self.dr,
                                 spec.shape[2] * self.dr)
        return spec_stacked

    def position_encoding(self, seq_len, batch_size=None, padding_idx=None):
        ''' Sinusoid position encoding table '''
        def cal_angle(position, hid_idx):
            return position / np.power(10000, 2 *
                                       (hid_idx // 2) / self.hidden_size)

        def get_posi_angle_vec(position):
            return [
                cal_angle(position, hid_j) for hid_j in range(self.hidden_size)
            ]

        sinusoid_table = np.array(
            [get_posi_angle_vec(pos_i) for pos_i in range(seq_len)])

        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        if padding_idx is not None:
            sinusoid_table[
                padding_idx:] = 0.  # zero vector for padding dimension

        if batch_size is not None:
            batch_sinusoid_table = np.repeat(sinusoid_table[np.newaxis, ...],
                                             batch_size,
                                             axis=0)
            return batch_sinusoid_table  # (batch_size, seq_len, hidden_size)
        else:
            return sinusoid_table  # (seq_len, hidden_size)