Example #1
    def setUp(self):
        model = ModelInputDict()

        sp_net_config = supernet(expand_ratio=[0.5, 1.0])
        self.model = Convert(sp_net_config).convert(model)
        self.images = paddle.randn(shape=[2, 3, 32, 32], dtype='float32')
        self.images2 = {
            'data': paddle.randn(shape=[2, 12, 32, 32], dtype='float32')
        }
        default_run_config = {'skip_layers': ['conv1.0', 'conv2.0']}
        self.run_config = RunConfig(**default_run_config)

        self.ofa_model = OFA(self.model, run_config=self.run_config)
        self.ofa_model._clear_search_space(self.images, data=self.images2)
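These snippets are fragments, so their imports are implicit. A minimal sketch of the imports they rely on, assuming the paddleslim module paths shown in Example #9 (DistillConfig is assumed to live in the same paddleslim.nas.ofa module; model classes such as ModelInputDict and Model are test fixtures defined alongside the snippets):

import paddle
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig, utils
from paddleslim.nas.ofa.convert_super import Convert, supernet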
Example #2
    def init_config(self):
        default_run_config = {
            'train_batch_size': 1,
            'n_epochs': [[2, 5]],
            'init_learning_rate': [[0.003, 0.001]],
            'dynamic_batch_size': [1],
            'total_images': 1,
        }
        self.run_config = RunConfig(**default_run_config)
        default_distill_config = {
            'teacher_model': self.teacher_model,
            'mapping_layers': ['models.3.fn'],
        }
        self.distill_config = DistillConfig(**default_distill_config)
        self.elastic_order = None
Example #3
    def init_config(self):
        default_run_config = {
            'train_batch_size': 1,
            'eval_batch_size': 1,
            'n_epochs': [[2, 5]],
            'init_learning_rate': [[0.003, 0.001]],
            'dynamic_batch_size': [1],
            'total_images': 1,
        }
        self.run_config = RunConfig(**default_run_config)

        default_distill_config = {
            'lambda_distill': 0.01,
            'teacher_model': self.teacher_model,
        }
        self.distill_config = DistillConfig(**default_distill_config)
Example #4
    def init_config(self):
        default_run_config = {
            'train_batch_size': 1,
            'eval_batch_size': 1,
            'n_epochs': [[1], [2, 3], [4, 5]],
            'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]],
            'dynamic_batch_size': [1, 1, 1],
            'total_images': 1,
            'elastic_depth': (5, 15, 24)
        }
        self.run_config = RunConfig(**default_run_config)

        default_distill_config = {
            'lambda_distill': 0.01,
            'teacher_model': self.teacher_model,
            'mapping_layers': ['models.0.fn']
        }
        self.distill_config = DistillConfig(**default_distill_config)
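The nested lists in these RunConfig dicts are indexed per elastic-training phase: n_epochs[i][j] and init_learning_rate[i][j] give the epoch endpoint and learning rate for the j-th stage of phase i, and dynamic_batch_size[i] is how many sub-networks get sampled per data batch in that phase; the training loop in Example #5 consumes them exactly this way. A small sketch of that indexing, assuming only RunConfig from paddleslim (the concrete values are illustrative):

from paddleslim.nas.ofa import RunConfig

cfg = RunConfig(n_epochs=[[1], [2, 3]],
                init_learning_rate=[[0.001], [0.003, 0.001]],
                dynamic_batch_size=[1, 2],
                train_batch_size=1,
                total_images=1)

for phase, stages in enumerate(cfg.n_epochs):
    for stage, end_epoch in enumerate(stages):
        lr = cfg.init_learning_rate[phase][stage]
        print('phase', phase, 'train until epoch', end_epoch, 'at lr', lr,
              'sampling', cfg.dynamic_batch_size[phase], 'sub-network(s) per batch')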
Example #5
def test_ofa():

    model = Model()
    teacher_model = Model()

    default_run_config = {
        'train_batch_size': 256,
        'n_epochs': [[1], [2, 3], [4, 5]],
        'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]],
        'dynamic_batch_size': [1, 1, 1],
        'total_images': 50000,  #1281167,
        'elastic_depth': (2, 5, 8)
    }
    run_config = RunConfig(**default_run_config)

    default_distill_config = {
        'lambda_distill': 0.01,
        'teacher_model': teacher_model,
        'mapping_layers': ['models.0.fn']
    }
    distill_config = DistillConfig(**default_distill_config)

    ofa_model = OFA(model, run_config, distill_config=distill_config)

    train_dataset = paddle.vision.datasets.MNIST(mode='train',
                                                 backend='cv2',
                                                 transform=transform)
    train_loader = paddle.io.DataLoader(train_dataset,
                                        places=place,
                                        feed_list=[image, label],
                                        drop_last=True,
                                        batch_size=64)

    start_epoch = 0
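    # outer loop: elastic training phases; inner loop: the per-phase (epoch endpoint, learning rate) schedule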
    for idx in range(len(run_config.n_epochs)):
        cur_idx = run_config.n_epochs[idx]
        for ph_idx in range(len(cur_idx)):
            cur_lr = run_config.init_learning_rate[idx][ph_idx]
            adam = paddle.optimizer.Adam(
                learning_rate=cur_lr,
                parameter_list=(ofa_model.parameters() +
                                ofa_model.netAs_param))
            for epoch_id in range(start_epoch,
                                  run_config.n_epochs[idx][ph_idx]):
                for batch_id, data in enumerate(train_loader()):
                    dy_x_data = np.array([
                        x[0].reshape(1, 28, 28) for x in data
                    ]).astype('float32')
                    y_data = np.array([x[1] for x in data
                                       ]).astype('int64').reshape(-1, 1)

                    img = paddle.dygraph.to_variable(dy_x_data)
                    label = paddle.dygraph.to_variable(y_data)
                    label.stop_gradient = True

                    for model_no in range(run_config.dynamic_batch_size[idx]):
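                        # each pass samples one sub-network; its gradients are accumulated with the others below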
                        output, _ = ofa_model(img, label)
                        loss = F.mean(output)
                        dis_loss = ofa_model.calc_distill_loss()
                        loss += dis_loss
                        loss.backward()

                        if batch_id % 10 == 0:
                            print(
                                'epoch: {}, batch: {}, loss: {}, distill loss: {}'
                                .format(epoch_id, batch_id,
                                        loss.numpy()[0],
                                        dis_loss.numpy()[0]))
                    ### accumulate the gradients of dynamic_batch_size sub-networks for the same batch of data
                    ### NOTE: gradient accumulation still needs to be fixed in PaddlePaddle
                    adam.minimize(loss)
                    adam.clear_gradients()
            start_epoch = run_config.n_epochs[idx][ph_idx]
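A note on the distillation terms in the loop above: ofa_model.calc_distill_loss() returns the feature-matching loss between student and teacher on the sublayers named in mapping_layers ('models.0.fn' here); lambda_distill in the DistillConfig weights that term, so each sampled sub-network trains on the task loss plus the distillation loss.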
Example #6
        sp_config = supernet(expand_ratio=args.width_mult_list)
        model = Convert(sp_config).convert(model)
        utils.set_state_dict(model, origin_weights)
        del origin_weights

        teacher_model = ErnieModelForSequenceClassification.from_pretrained(
            args.from_pretrained, num_labels=3, name='teacher')
        setattr(teacher_model, 'return_additional_info', True)

        default_run_config = {
            'n_epochs': [[4 * args.epoch], [6 * args.epoch]],
            'init_learning_rate': [[args.lr], [args.lr]],
            'elastic_depth': args.depth_mult_list,
            'dynamic_batch_size': [[1, 1], [1, 1]]
        }
        run_config = RunConfig(**default_run_config)

        model_cfg = get_config(args.from_pretrained)

        default_distill_config = {'teacher_model': teacher_model}
        distill_config = DistillConfig(**default_distill_config)

        ofa_model = OFA(model,
                        run_config,
                        distill_config=distill_config,
                        elastic_order=['width', 'depth'])

        ### elastic width is searched first under this elastic_order
        if args.reorder_weight:
            head_importance, neuron_importance = compute_neuron_head_importance(
                args, ofa_model.model, tokenizer, dev_ds, place, model_cfg)
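Because elastic_order=['width', 'depth'], the width search runs first; the head and neuron importance scores computed here are typically used to reorder attention heads and FFN neurons so that the most important units survive width shrinking (the reordering call itself falls outside this fragment).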
Example #7
def do_train(args):
    paddle.set_device("gpu" if args.n_gpu else "cpu")
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    set_seed(args)

    args.task_name = args.task_name.lower()
    dataset_class, metric_class = TASK_CLASSES[args.task_name]
    args.model_type = args.model_type.lower()
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    train_ds = dataset_class.get_datasets(['train'])

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    trans_func = partial(convert_example,
                         tokenizer=tokenizer,
                         label_list=train_ds.get_labels(),
                         max_seq_length=args.max_seq_length)
    train_ds = train_ds.apply(trans_func, lazy=True)
    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_ds, batch_size=args.batch_size, shuffle=True)
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # segment
        Stack(),  # length
        Stack(dtype="int64" if train_ds.get_labels() else "float32")  # label
    ): [data for i, data in enumerate(fn(samples)) if i != 2]
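    # the filter above drops column 2 (the Stack()'ed length) and keeps input ids, segment ids and labels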
    train_data_loader = DataLoader(dataset=train_ds,
                                   batch_sampler=train_batch_sampler,
                                   collate_fn=batchify_fn,
                                   num_workers=0,
                                   return_list=True)
    if args.task_name == "mnli":
        dev_dataset_matched, dev_dataset_mismatched = dataset_class.get_datasets(
            ["dev_matched", "dev_mismatched"])
        dev_dataset_matched = dev_dataset_matched.apply(trans_func, lazy=True)
        dev_dataset_mismatched = dev_dataset_mismatched.apply(trans_func,
                                                              lazy=True)
        dev_batch_sampler_matched = paddle.io.BatchSampler(
            dev_dataset_matched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_matched = DataLoader(
            dataset=dev_dataset_matched,
            batch_sampler=dev_batch_sampler_matched,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)
        dev_batch_sampler_mismatched = paddle.io.BatchSampler(
            dev_dataset_mismatched, batch_size=args.batch_size, shuffle=False)
        dev_data_loader_mismatched = DataLoader(
            dataset=dev_dataset_mismatched,
            batch_sampler=dev_batch_sampler_mismatched,
            collate_fn=batchify_fn,
            num_workers=0,
            return_list=True)
    else:
        dev_dataset = dataset_class.get_datasets(["dev"])
        dev_dataset = dev_dataset.apply(trans_func, lazy=True)
        dev_batch_sampler = paddle.io.BatchSampler(dev_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=False)
        dev_data_loader = DataLoader(dataset=dev_dataset,
                                     batch_sampler=dev_batch_sampler,
                                     collate_fn=batchify_fn,
                                     num_workers=0,
                                     return_list=True)

    num_labels = 1 if train_ds.get_labels() is None else len(
        train_ds.get_labels())

    # Step1: Initialize the origin BERT model.
    model = model_class.from_pretrained(args.model_name_or_path,
                                        num_classes=num_labels)
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    # Step2: Convert origin model to supernet.
    sp_config = supernet(expand_ratio=args.width_mult_list)
    model = Convert(sp_config).convert(model)

    # Use weights saved in the dictionary to initialize supernet.
    weights_path = os.path.join(args.model_name_or_path,
                                'model_state.pdparams')
    origin_weights = paddle.load(weights_path)
    model.set_state_dict(origin_weights)

    # Step3: Define teacher model.
    teacher_model = model_class.from_pretrained(args.model_name_or_path,
                                                num_classes=num_labels)
    new_dict = utils.utils.remove_model_fn(teacher_model, origin_weights)
    teacher_model.set_state_dict(new_dict)
    del origin_weights, new_dict

    default_run_config = {'elastic_depth': args.depth_mult_list}
    run_config = RunConfig(**default_run_config)

    # Step4: Config about distillation.
    mapping_layers = ['bert.embeddings']
    for idx in range(model.bert.config['num_hidden_layers']):
        mapping_layers.append('bert.encoder.layers.{}'.format(idx))

    default_distill_config = {
        'lambda_distill': args.lambda_rep,
        'teacher_model': teacher_model,
        'mapping_layers': mapping_layers,
    }
    distill_config = DistillConfig(**default_distill_config)

    # Step5: Config in supernet training.
    ofa_model = OFA(model,
                    run_config=run_config,
                    distill_config=distill_config,
                    elastic_order=['depth'])
    #elastic_order=['width'])

    criterion = (paddle.nn.loss.CrossEntropyLoss()
                 if train_ds.get_labels() else paddle.nn.loss.MSELoss())

    metric = metric_class()

    if args.task_name == "mnli":
        dev_data_loader = (dev_data_loader_matched, dev_data_loader_mismatched)

    lr_scheduler = paddle.optimizer.lr.LambdaDecay(
        args.learning_rate,
        lambda current_step, num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps if args.max_steps > 0 else
        (len(train_data_loader) * args.num_train_epochs): float(
            current_step) / float(max(1, num_warmup_steps))
        if current_step < num_warmup_steps else max(
            0.0,
            float(num_training_steps - current_step) / float(
                max(1, num_training_steps - num_warmup_steps))))
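    # the LambdaDecay above implements linear warmup for args.warmup_steps followed by linear decay to zero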

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=ofa_model.model.parameters(),
        weight_decay=args.weight_decay,
        apply_decay_param_fun=lambda x: x in [
            p.name for n, p in ofa_model.model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ])

    global_step = 0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        # Step6: Set current epoch and task.
        ofa_model.set_epoch(epoch)
        ofa_model.set_task('depth')

        for step, batch in enumerate(train_data_loader):
            global_step += 1
            input_ids, segment_ids, labels = batch

            for depth_mult in args.depth_mult_list:
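                # gradients from every (depth, width) sub-network config are accumulated before the single optimizer.step() below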
                for width_mult in args.width_mult_list:
                    # Step7: Broadcast supernet config from width_mult,
                    # and use this config in supernet training.
                    net_config = utils.dynabert_config(ofa_model, width_mult,
                                                       depth_mult)
                    ofa_model.set_net_config(net_config)
                    logits, teacher_logits = ofa_model(
                        input_ids, segment_ids, attention_mask=[None, None])
                    rep_loss = ofa_model.calc_distill_loss()
                    if args.task_name == 'sts-b':
                        logit_loss = 0.0
                    else:
                        logit_loss = soft_cross_entropy(
                            logits, teacher_logits.detach())
                    loss = rep_loss + args.lambda_logit * logit_loss
                    loss.backward()
            optimizer.step()
            lr_scheduler.step()
            ofa_model.model.clear_gradients()

            if global_step % args.logging_steps == 0:
                if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0:
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s"
                        % (global_step, epoch, step, loss, args.logging_steps /
                           (time.time() - tic_train)))
                tic_train = time.time()

            if global_step % args.save_steps == 0:
                if args.task_name == "mnli":
                    evaluate(teacher_model,
                             criterion,
                             metric,
                             dev_data_loader_matched,
                             width_mult=100)
                    evaluate(teacher_model,
                             criterion,
                             metric,
                             dev_data_loader_mismatched,
                             width_mult=100)
                else:
                    evaluate(teacher_model,
                             criterion,
                             metric,
                             dev_data_loader,
                             width_mult=100)
                for depth_mult in args.depth_mult_list:
                    for width_mult in args.width_mult_list:
                        net_config = utils.dynabert_config(
                            ofa_model, width_mult, depth_mult)
                        ofa_model.set_net_config(net_config)
                        tic_eval = time.time()
                        if args.task_name == "mnli":
                            acc = evaluate(ofa_model, criterion, metric,
                                           dev_data_loader_matched, width_mult,
                                           depth_mult)
                            evaluate(ofa_model, criterion, metric,
                                     dev_data_loader_mismatched, width_mult,
                                     depth_mult)
                            print("eval done total : %s s" %
                                  (time.time() - tic_eval))
                        else:
                            acc = evaluate(ofa_model, criterion, metric,
                                           dev_data_loader, width_mult,
                                           depth_mult)
                            print("eval done total : %s s" %
                                  (time.time() - tic_eval))

                        if (not args.n_gpu > 1
                            ) or paddle.distributed.get_rank() == 0:
                            output_dir = os.path.join(args.output_dir,
                                                      "model_%d" % global_step)
                            if not os.path.exists(output_dir):
                                os.makedirs(output_dir)
                            # need better way to get inner model of DataParallel
                            model_to_save = model._layers if isinstance(
                                model, paddle.DataParallel) else model
                            model_to_save.save_pretrained(output_dir)
                            tokenizer.save_pretrained(output_dir)
Example #8
    def init_config(self):
        default_run_config = {'skip_layers': ['branch2.0']}
        self.run_config = RunConfig(**default_run_config)
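skip_layers keeps the named sublayers ('branch2.0' here) out of the OFA search space, so they are left unchanged during supernet training; Example #9 builds the same kind of list programmatically to exclude a detector's neck and head.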
Example #9
    def __call__(self, model, param_state_dict):

        paddleslim = try_import('paddleslim')
        from paddleslim.nas.ofa import OFA, RunConfig, utils
        from paddleslim.nas.ofa.convert_super import Convert, supernet
        task = self.ofa_config['task']
        expand_ratio = self.ofa_config['expand_ratio']

        skip_neck = self.ofa_config['skip_neck']
        skip_head = self.ofa_config['skip_head']

        run_config = self.ofa_config['RunConfig']
        if 'skip_layers' in run_config:
            skip_layers = run_config['skip_layers']
        else:
            skip_layers = []

        # supernet config
        sp_config = supernet(expand_ratio=expand_ratio)
        # convert to supernet
        model = Convert(sp_config).convert(model)

        skip_names = []
        if skip_neck:
            skip_names.append('neck.')
        if skip_head:
            skip_names.append('head.')

        for name, sublayer in model.named_sublayers():
            for n in skip_names:
                if n in name:
                    skip_layers.append(name)

        run_config['skip_layers'] = skip_layers
        run_config = RunConfig(**run_config)

        # build ofa model
        ofa_model = OFA(model, run_config=run_config)

        ofa_model.set_epoch(0)
        ofa_model.set_task(task)

        input_spec = [{
            "image": paddle.ones(
                shape=[1, 3, 640, 640], dtype='float32'),
            "im_shape": paddle.full(
                [1, 2], 640, dtype='float32'),
            "scale_factor": paddle.ones(
                shape=[1, 2], dtype='float32')
        }]

        ofa_model._clear_search_space(input_spec=input_spec)
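        # the dummy detector inputs above give OFA a full forward pass so it can build and prune its search space before tokenization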
        ofa_model._build_ss = True
        check_ss = ofa_model._sample_config('expand_ratio', phase=None)
        # tokenize the search space
        ofa_model.tokenize()
        # check token map, search cands and search space
        logger.info('Token map is {}'.format(ofa_model.token_map))
        logger.info('Search candidates is {}'.format(ofa_model.search_cands))
        logger.info('The length of search_space is {}, search_space is {}'.
                    format(len(ofa_model._ofa_layers), ofa_model._ofa_layers))
        # set model state dict into ofa model
        utils.set_state_dict(ofa_model.model, param_state_dict)
        return ofa_model