Example no. 1
def init_dataloader(config):
    trainloader, testloader = get_dataloader(
        dataset=config.dataset,
        train_batch_size=config.batch_size,
        test_batch_size=256,
        returnset=config.data_distributed)
    return trainloader, testloader
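A minimal usage sketch for the helper above; the namespace-style config is hypothetical and sets only the fields init_dataloader actually reads:

# hypothetical config object, for illustration only
from types import SimpleNamespace

example_config = SimpleNamespace(dataset='cifar10', batch_size=128,
                                 data_distributed=False)
trainloader, testloader = init_dataloader(example_config)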
Example no. 2
def main(config, args):
    # init logger
    classes = {
        'cifar10': 10,
        'cifar100': 100,
        'mnist': 10,
        'tiny_imagenet': 200
    }
    logger, writer = init_logger(config, args)
    best_acc_vec = []
    test_acc_vec_vec = []

    for n_runs in range(1):
        if args.sigma_w2 is not None and n_runs != 0:
            break

        # build model
        model = get_network(config.network,
                            config.depth,
                            config.dataset,
                            use_bn=config.get('use_bn', args.bn),
                            scaled=args.scaled_init,
                            act=args.act)
        mask = None
        mb = ModelBase(config.network, config.depth, config.dataset, model)
        mb.cuda()
        if mask is not None:
            mb.register_mask(mask)
            ratio_vec_ = print_mask_information(mb, logger)

        # preprocessing
        # ====================================== get dataloader ======================================
        trainloader, testloader = get_dataloader(config.dataset,
                                                 config.batch_size, 256, 4)
        # ====================================== fetch configs ======================================
        ckpt_path = config.checkpoint_dir
        num_iterations = config.iterations
        if args.target_ratio is None:
            target_ratio = config.target_ratio
        else:
            target_ratio = args.target_ratio

        normalize = config.normalize
        # ====================================== fetch exception ======================================
        exception = get_exception_layers(
            mb.model, str_to_list(config.exception, ',', int))
        logger.info('Exception: ')

        for idx, m in enumerate(exception):
            logger.info('  (%d) %s' % (idx, m))

        # ====================================== fetch training schemes ======================================
        ratio = 1 - (1 - target_ratio)**(1.0 / num_iterations)
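        # e.g. target_ratio=0.9 with num_iterations=2 gives
        # ratio = 1 - 0.1**0.5 ≈ 0.684: pruning 68.4% per iteration leaves
        # (1 - 0.684)**2 ≈ 10% of the weights, i.e. the 90% target overall.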
        learning_rates = str_to_list(config.learning_rate, ',', float)
        weight_decays = str_to_list(config.weight_decay, ',', float)
        training_epochs = str_to_list(config.epoch, ',', int)
        logger.info(
            'Normalize: %s, Total iteration: %d, Target ratio: %.2f, Iter ratio %.4f.'
            % (normalize, num_iterations, target_ratio, ratio))
        logger.info('Basic Settings: ')
        for idx in range(len(learning_rates)):
            logger.info('  %d: LR: %.5f, WD: %.5f, Epochs: %d' %
                        (idx, learning_rates[idx], weight_decays[idx],
                         training_epochs[idx]))

        # ====================================== start pruning ======================================
        iteration = 0
        for _ in range(1):
            logger.info(
                '** Target ratio: %.4f, iter ratio: %.4f, iteration: %d/%d.' %
                (target_ratio, ratio, 1, num_iterations))

            # mb.model.apply(weights_init)
            print('#' * 40)
            print('USING {} INIT SCHEME'.format(args.init))
            print('#' * 40)
            if args.init == 'kaiming_xavier':
                mb.model.apply(weights_init_kaiming_xavier)
            elif args.init == 'kaiming':
                if args.act == 'relu' or args.act == 'elu':
                    mb.model.apply(weights_init_kaiming_relu)
                elif args.act == 'tanh':
                    mb.model.apply(weights_init_kaiming_tanh)
            elif args.init == 'xavier':
                mb.model.apply(weights_init_xavier)
            elif args.init == 'EOC':
                mb.model.apply(weights_init_EOC)
            elif args.init == 'ordered':

                def weights_init_ord(m):
                    if isinstance(m, nn.Conv2d):
                        ord_weights(m.weight, sigma_w2=args.sigma_w2)
                        if m.bias is not None:
                            ord_bias(m.bias)
                    elif isinstance(m, nn.Linear):
                        ord_weights(m.weight, sigma_w2=args.sigma_w2)
                        if m.bias is not None:
                            ord_bias(m.bias)
                    elif isinstance(m, nn.BatchNorm2d):
                        # Note that BN's running_var/mean are
                        # already initialized to 1 and 0 respectively.
                        if m.weight is not None:
                            m.weight.data.fill_(1.0)
                        if m.bias is not None:
                            m.bias.data.zero_()

                mb.model.apply(weights_init_ord)
            else:
                raise NotImplementedError

            print("=> Applying weight initialization(%s)." %
                  config.get('init_method', 'kaiming'))
            print("Iteration of: %d/%d" % (iteration, num_iterations))

            if config.pruner == 'SNIP':
                print('=> Using SNIP')
                masks, scaled_masks = SNIP(
                    mb.model,
                    ratio,
                    trainloader,
                    'cuda',
                    num_classes=classes[config.dataset],
                    samples_per_class=config.samples_per_class,
                    num_iters=config.get('num_iters', 1),
                    scaled_init=args.scaled_init)
            elif config.pruner == 'GraSP':
                print('=> Using GraSP')
                masks, scaled_masks = GraSP(
                    mb.model,
                    ratio,
                    trainloader,
                    'cuda',
                    num_classes=classes[config.dataset],
                    samples_per_class=config.samples_per_class,
                    num_iters=config.get('num_iters', 1),
                    scaled_init=args.scaled_init)
            else:
                raise NotImplementedError('Unsupported pruner: %s' %
                                          config.pruner)

            ################################################################################
            _masks = None
            _masks_scaled = None
            if not args.bn:
                # build a model with the same weights as the pruned network, but with BN enabled this time
                model2 = get_network(config.network,
                                     config.depth,
                                     config.dataset,
                                     use_bn=config.get('use_bn', True),
                                     scaled=args.scaled_init,
                                     act=args.act)
                weights_temp = []
                for layer_old in mb.model.modules():
                    if isinstance(layer_old, (nn.Conv2d, nn.Linear)):
                        weights_temp.append(layer_old.weight)
                idx = 0
                for layer_new in model2.modules():
                    if isinstance(layer_new, (nn.Conv2d, nn.Linear)):
                        layer_new.weight.data = weights_temp[idx]
                        idx += 1

                # Creating a base model with BN included now
                mb = ModelBase(config.network, config.depth, config.dataset,
                               model2)
                mb.cuda()

                _masks = dict()
                _masks_scaled = dict()
                layer_keys_new = []
                for layer in mb.model.modules():
                    if isinstance(layer, (nn.Conv2d, nn.Linear)):
                        layer_keys_new.append(layer)
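                # layer_keys_new now lists model2's prunable layers in
                # traversal order; the zip below assumes this order matches
                # the old model's, so each mask is re-keyed onto its
                # counterpart one-to-one.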

                for new_keys, old_keys in zip(layer_keys_new, masks.keys()):
                    _masks[new_keys] = masks[old_keys]
                    if args.scaled_init:
                        _masks_scaled[new_keys] = scaled_masks[old_keys]
            ################################################################################

            if _masks is None:
                _masks = masks
                _masks_scaled = scaled_masks

            # ========== register mask ==================
            mb.register_mask(_masks)

            # ========== debugging ==================

            if args.scaled_init and config.network == 'vgg':
                print('scaling VGG')
                mb.scaling_weights(_masks_scaled)

            # ========== save pruned network ============
            logger.info('Saving..')
            state = {
                'net': mb.model,
                'acc': -1,
                'epoch': -1,
                'args': config,
                'mask': mb.masks,
                'ratio': mb.get_ratio_at_each_layer()
            }
            path = os.path.join(
                ckpt_path, 'prune_%s_%s%s_r%s_it%d.pth.tar' %
                (config.dataset, config.network, config.depth, target_ratio,
                 iteration))
            torch.save(state, path)

            # ========== print pruning details ============
            logger.info('**[%d] Mask and training setting: ' % iteration)
            ratio_vec_ = print_mask_information(mb, logger)
            logger.info('  LR: %.5f, WD: %.5f, Epochs: %d' %
                        (learning_rates[iteration], weight_decays[iteration],
                         training_epochs[iteration]))

            results_path = os.path.join(
                config.summary_dir,
                args.init + '_sp' + str(args.target_ratio).replace('.', '_'))
            if args.scaled_init:
                results_path += '_scaled'
            if args.bn:
                results_path += '_bn'

            if args.sigma_w2 is not None and args.init == 'ordered':
                results_path += '_sgw2{}'.format(args.sigma_w2).replace(
                    '.', '_')

            results_path += '_' + args.act + '_' + str(config.depth)
            print('saving the ratios')
            print(results_path)
            os.makedirs(results_path, exist_ok=True)
            np.save(results_path + '/ratios_pruned{}'.format(args.seed_tiny),
                    np.array(ratio_vec_))

            # if args.sigma_w2 != None:
            # 	break
            # ========== finetuning =======================
            best_acc, test_acc_vec = train_once(
                mb=mb,
                net=mb.model,
                trainloader=trainloader,
                testloader=testloader,
                writer=writer,
                config=config,
                ckpt_path=ckpt_path,
                learning_rate=learning_rates[iteration],
                weight_decay=weight_decays[iteration],
                num_epochs=training_epochs[iteration],
                iteration=iteration,
                logger=logger,
                args=args)

            best_acc_vec.append(best_acc)
            test_acc_vec_vec.append(test_acc_vec)

            np.save(results_path + '/best_acc{}'.format(args.seed_tiny),
                    np.array(best_acc_vec))
            np.save(results_path + '/test_acc{}'.format(args.seed_tiny),
                    np.array(test_acc_vec_vec))
Example no. 3
# activation lookup; the head of this dict is truncated in the excerpt, so
# the entries below are a hypothetical reconstruction, not the original ones
act_dict = {
    'relu': torch.nn.ReLU(),
    'tanh': torch.nn.Tanh(),
    'sigmoid': torch.nn.Sigmoid(),
}

act = torch.nn.Sigmoid() if args.activation is None else act_dict[
    args.activation]

# init model
encoder_sizes = [28 * 28, 1000, 500, 250, 30]
decoder_sizes = [30, 250, 500, 1000, 28 * 28]

net = deep_autoencoder(encoder_sizes=encoder_sizes,
                       decoder_sizes=decoder_sizes,
                       activation=act).to(args.device)

# init dataloader
trainloader, testloader = get_dataloader(dataset=args.dataset,
                                         train_batch_size=args.batch_size,
                                         test_batch_size=256)
# init optimizer
optim_name = args.optimizer.lower()
tag = optim_name
optimizer = get_optimizer(optim_name, net, args)

# init lr scheduler
lr_scheduler = get_lr_scheduler(optimizer, args)

# init criterion
criterion = torch.nn.BCEWithLogitsLoss()
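# note: BCEWithLogitsLoss fuses a sigmoid with binary cross-entropy, so the
# decoder should emit raw logits; pixel values in [0, 1] act as soft targets.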

# init summary writer
log_dir = get_log_dir(optim_name, args)
if not os.path.isdir(log_dir):
    os.makedirs(log_dir)
Example no. 4
def main():
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--config",
                        default=None,
                        type=str,
                        required=True,
                        help="the training config file")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--multi_task",
                        action="store_true",
                        help="training with multi task schema")
    parser.add_argument("--debug",
                        action="store_true",
                        help="in debug mode, will not enable wandb log")
    parser.add_argument("--use_wandb",
                        action="store_true",
                        help="whether or not use wandb")
    args = parser.parse_args()
    cfg = parse_cfg(pathlib.Path(args.config))

    # set CUDA_VISIBLE_DEVICES and get num_gpus
    if args.local_rank == -1:  # not distributed
        os.environ["CUDA_VISIBLE_DEVICES"] = cfg["system"][
            "cuda_visible_devices"]
        num_gpus = torch.cuda.device_count()
        args.distributed = False
    else:  # distributed
        torch.cuda.set_device(args.local_rank)
        num_gpus = 1
        args.distributed = True
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    logger.info(
        "num_gpus: {}, distributed training: {}, 16-bits training: {}".format(
            num_gpus, bool(args.local_rank != -1), cfg["train"]["fp16"]))
    cudnn.benchmark = True

    cfg["train"]["output_dir"] = cfg["train"]["output_dir"] + "/" + \
                                 cfg["train"]["task_name"] + "_" + \
                                 cfg["train"]["model_name"] + "_" + \
                                 cfg["data"]["corpus"]

    output_dir_pl = pathlib.Path(cfg["train"]["output_dir"])
    if output_dir_pl.exists():
        logger.warning(
            "output directory ({}) already exists, continuing after 2 seconds..."
            .format(output_dir_pl))
        time.sleep(2)
    else:
        output_dir_pl.mkdir(parents=True, exist_ok=True)

    if not args.debug and args.use_wandb:
        config_dictionary = dict(yaml=cfg, params=args)
        wandb.init(config=config_dictionary,
                   project="nlp-task",
                   dir=cfg["train"]["output_dir"])
        wandb.run.name = cfg["data"]["corpus"] + '-' + cfg["train"][
            "pretrained_tag"] + '-' + time.strftime("%Y-%m-%d %H:%M:%S",
                                                    time.localtime())
        wandb.config.update(args)
        wandb.run.save()

    if cfg["optimizer"]["gradient_accumulation_steps"] < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(cfg["optimizer"]["gradient_accumulation_steps"]))

    # true batch_size in training
    cfg["train"]["batch_size"] = cfg["train"]["batch_size"] // cfg[
        "optimizer"]["gradient_accumulation_steps"]

    # the type of label_map is bidict
    # label_map[x] = xx, label_map.inv[xx] = x
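    # e.g. (hypothetical labels) label_map['B-PER'] == 0 and
    # label_map.inv[0] == 'B-PER'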
    label_map, num_labels = get_label_map(cfg)
    tokenizer, model = get_tokenizer_and_model(cfg, label_map, num_labels)

    # check model details on wandb
    if not args.debug and args.use_wandb:
        wandb.watch(model)

    num_examples, train_dataloader = get_dataloader(cfg,
                                                    tokenizer,
                                                    num_labels,
                                                    "train",
                                                    debug=args.debug)
    _, eval_dataloader = get_dataloader(cfg,
                                        tokenizer,
                                        num_labels,
                                        "dev",
                                        debug=args.debug)

    # total training steps (including multi epochs)
    num_training_steps = int(
        len(train_dataloader) //
        cfg["optimizer"]["gradient_accumulation_steps"] *
        cfg["train"]["train_epochs"])

    optimizer = AdamW(params=model.parameters(), lr=cfg["optimizer"]["lr"])
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=cfg["optimizer"]["num_warmup_steps"],
        num_training_steps=num_training_steps)

    scaler = None
    model = model.cuda()
    if cfg["train"]["fp16"] and _use_apex:
        logger.error("using apex amp for fp16...")
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    elif cfg["train"]["fp16"] and _use_native_amp:
        logger.error("using pytorch native amp for fp16...")
        scaler = torch.cuda.amp.GradScaler()
    elif cfg["train"]["fp16"] and (_use_apex is False
                                   and _use_native_amp is False):
        logger.error("your environment DO NOT support fp16 training...")
        exit()

    if cfg["system"]["distributed"]:
        # TODO distributed debug
        model.cuda(args.local_rank)
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(model, device_ids=[args.local_rank])
    elif num_gpus > 1:
        model = torch.nn.DataParallel(model)

    # Train
    logger.info("start training on train set")
    epoch = 0
    best_score = -1
    for _ in trange(int(cfg["train"]["train_epochs"]), desc="Epoch"):
        best = False
        # train loop in one epoch
        train_loop(cfg, model, train_dataloader, optimizer, lr_scheduler,
                   num_gpus, epoch, scaler, args.debug, args.use_wandb)
        # begin to evaluate
        logger.info("running evaluation on dev set")
        score = eval_loop(cfg, tokenizer, model, eval_dataloader, label_map,
                          args.debug, args.use_wandb)
        if best_score < score:
            best_score = score
            best = True
        # Save a trained model and the associated configuration
        save_model(cfg, tokenizer, model, best)

        epoch += 1

    # Test Eval
    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        logger.info("running evaluation on final test set")
        # TODO add stand alone test set
        _, eval_dataloader = get_dataloader(cfg,
                                            tokenizer,
                                            num_labels,
                                            "dev",
                                            debug=args.debug)
        score = eval_loop(cfg, tokenizer, model, eval_dataloader, label_map,
                          args.debug, args.use_wandb)
Example no. 5
def init_dataloader(config):
    trainloader, testloader = get_dataloader(dataset=config.dataset,
                                             train_batch_size=config.batch_size,
                                             test_batch_size=128)
    return trainloader, testloader
Example no. 6
def main(config):
    # init logger
    classes = {
        'cifar10': 10,
        'cifar100': 100,
        'mnist': 10,
        'tiny_imagenet': 200
    }
    logger, writer = init_logger(config)

    # build model
    model = get_network(config.network, config.depth, config.dataset, use_bn=config.get('use_bn', True))
    mask = None
    mb = ModelBase(config.network, config.depth, config.dataset, model)
    mb.cuda()
    if mask is not None:
        mb.register_mask(mask)
        print_mask_information(mb, logger)

    # preprocessing
    # ====================================== get dataloader ======================================
    trainloader, testloader = get_dataloader(config.dataset, config.batch_size, 256, 4, root='/home/wzn/PycharmProjects/GraSP/data')
    # ====================================== fetch configs ======================================
    ckpt_path = config.checkpoint_dir
    num_iterations = config.iterations
    target_ratio = config.target_ratio
    normalize = config.normalize
    # ====================================== fetch exception ======================================
    # exception = get_exception_layers(mb.model, str_to_list(config.exception, ',', int))
    # logger.info('Exception: ')
    #
    # for idx, m in enumerate(exception):
    #     logger.info('  (%d) %s' % (idx, m))

    # ====================================== fetch training schemes ======================================
    ratio = 1 - (1 - target_ratio) ** (1.0 / num_iterations)
    learning_rates = str_to_list(config.learning_rate, ',', float)
    weight_decays = str_to_list(config.weight_decay, ',', float)
    training_epochs = str_to_list(config.epoch, ',', int)
    logger.info('Normalize: %s, Total iteration: %d, Target ratio: %.2f, Iter ratio %.4f.' %
                (normalize, num_iterations, target_ratio, ratio))
    logger.info('Basic Settings: ')
    for idx in range(len(learning_rates)):
        logger.info('  %d: LR: %.5f, WD: %.5f, Epochs: %d' % (idx,
                                                              learning_rates[idx],
                                                              weight_decays[idx],
                                                              training_epochs[idx]))

    # ====================================== start pruning ======================================
    iteration = 0
    for _ in range(1):
        # logger.info('** Target ratio: %.4f, iter ratio: %.4f, iteration: %d/%d.' % (target_ratio,
        #                                                                             ratio,
        #                                                                             1,
        #                                                                             num_iterations))

        mb.model.apply(weights_init)
        print("=> Applying weight initialization(%s)." % config.get('init_method', 'kaiming'))


        # print("Iteration of: %d/%d" % (iteration, num_iterations))
        # masks = GraSP(mb.model, ratio, trainloader, 'cuda',
        #               num_classes=classes[config.dataset],
        #               samples_per_class=config.samples_per_class,
        #               num_iters=config.get('num_iters', 1))
        # iteration = 0
        # print('=> Using GraSP')
        # # ========== register mask ==================
        # mb.register_mask(masks)
        # # ========== save pruned network ============
        # logger.info('Saving..')
        # state = {
        #     'net': mb.model,
        #     'acc': -1,
        #     'epoch': -1,
        #     'args': config,
        #     'mask': mb.masks,
        #     'ratio': mb.get_ratio_at_each_layer()
        # }
        # path = os.path.join(ckpt_path, 'prune_%s_%s%s_r%s_it%d.pth.tar' % (config.dataset,
        #                                                                    config.network,
        #                                                                    config.depth,
        #                                                                    config.target_ratio,
        #                                                                    iteration))
        # torch.save(state, path)

        # # ========== print pruning details ============
        # logger.info('**[%d] Mask and training setting: ' % iteration)
        # print_mask_information(mb, logger)
        # logger.info('  LR: %.5f, WD: %.5f, Epochs: %d' %
        #             (learning_rates[iteration], weight_decays[iteration], training_epochs[iteration]))

        # ========== finetuning =======================
        train_once(mb=mb,
                   net=mb.model,
                   trainloader=trainloader,
                   testloader=testloader,
                   writer=writer,
                   config=config,
                   ckpt_path=ckpt_path,
                   learning_rate=learning_rates[iteration],
                   weight_decay=weight_decays[iteration],
                   num_epochs=training_epochs[iteration],
                   iteration=iteration,
                   logger=logger)