Example 1
def pretrain_finetune(args, logger):
    # data
    source_train_dir, source_valid_dir = default_dir[args.source_dataset]
    source_train_trans, source_valid_trans = hard_trans(args.source_img_size)

    source_trainset = MyDataset(source_train_dir, transform=source_train_trans)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        source_trainset)
    source_trainloader = DataLoader(source_trainset,
                                    batch_size=args.source_bs,
                                    shuffle=False,
                                    num_workers=8,
                                    pin_memory=True,
                                    sampler=train_sampler)

    source_validset = MyDataset(source_valid_dir, transform=source_valid_trans)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        source_validset)
    source_validloader = DataLoader(source_validset,
                                    batch_size=args.target_bs,
                                    shuffle=False,
                                    num_workers=8,
                                    pin_memory=True,
                                    sampler=valid_sampler)

    # model: load the seed net
    basenet = resnet_nas(args.base_choose)
    assert args.source_arch_dir is not None, 'need source_arch_dir'
    basenet.load_state_dict(load_normal(args.source_arch_dir))
    seednet = resnet_nas(args.seed_choose)
    seednet = copy.deepcopy(remap_res2arch(basenet, seednet))
    fc = models.__dict__['cos'](source_trainloader.dataset.n_classes,
                                seednet.emb_size, 100.0)
    criterion = models.__dict__['cross_entropy'](1.0)
    seednet = seednet.cuda()
    fc = fc.cuda()
    criterion = criterion.cuda()

    # optimizer
    optimizer = optimizers.__dict__['sgd']((seednet, fc), args.optim_lr)
    scheduler = optimizers.__dict__['warm_cos'](optimizer,
                                                args.source_warmup_epoch,
                                                args.source_max_epoch,
                                                len(source_trainloader))

    # ddp
    seednet = torch.nn.SyncBatchNorm.convert_sync_batchnorm(seednet)
    seednet = ddp(seednet,
                  device_ids=[args.local_rank],
                  output_device=args.local_rank,
                  find_unused_parameters=True)
    fc = ddp(fc,
             device_ids=[args.local_rank],
             output_device=args.local_rank,
             find_unused_parameters=True)

    # train
    for i_epoch in range(args.source_max_epoch):
        seednet.train()
        fc.train()
        source_trainloader.sampler.set_epoch(i_epoch)
        correct = torch.tensor(0.0).cuda()
        total = torch.tensor(0.0).cuda()
        start_time = time.time()

        for i_iter, data in enumerate(source_trainloader):
            img, label = data[:2]
            img, label = img.cuda(), label.cuda()

            optimizer.zero_grad()
            f = seednet(img)
            s = fc(f)
            loss = criterion(s, label)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # acc
            _, predicted = torch.max(s.data, 1)
            correct += predicted.eq(label.data).sum()
            total += label.size(0)

            # print info
            log_fre = len(
                source_trainloader
            ) if args.source_log_fre == -1 else args.source_log_fre
            if (i_iter + 1) % log_fre == 0:
                correct_tmp = correct.clone()
                total_tmp = total.clone()
                dist.reduce(correct_tmp, dst=0, op=dist.ReduceOp.SUM)
                dist.reduce(total_tmp, dst=0, op=dist.ReduceOp.SUM)
                if args.local_rank == 0:
                    eta = (time.time() - start_time) / 60.
                    logger.info(
                        "Training: Epoch[{:0>3}/{:0>3}] "
                        "Iter[{:0>3}/{:0>3}] "
                        "lr: {:.5f} "
                        "Loss: {:.4f} "
                        "Acc:{:.2%} "
                        "Run-T:{:.2f}m".format(
                            i_epoch + 1, args.source_max_epoch, i_iter + 1,
                            len(source_trainloader),
                            optimizer.state_dict()['param_groups'][0]['lr'],
                            loss.cpu().item() / dist.get_world_size(),
                            correct_tmp.cpu().item() /
                            total_tmp.cpu().item(), eta))
        if args.local_rank == 0:
            if not os.path.exists('tmp_model'):
                os.makedirs('tmp_model')
            torch.save(seednet.state_dict(), 'tmp_model/encoder.pth')
            torch.save(fc.state_dict(), 'tmp_model/fc.pth')

        # valid
        with torch.no_grad():
            seednet.eval()
            fc.eval()
            correct = torch.tensor([0.0]).cuda()
            total = torch.tensor([0.0]).cuda()
            for data in source_validloader:
                img, label = data[:2]
                img, label = img.cuda(), label.cuda()

                feature = seednet(img)
                s = fc(feature)

                # acc
                _, predicted = torch.max(s.data, 1)
                correct += predicted.eq(label.data).sum()
                total += label.size(0)
            dist.reduce(correct, dst=0, op=dist.ReduceOp.SUM)
            dist.reduce(total, dst=0, op=dist.ReduceOp.SUM)
            if args.local_rank == 0:
                logger.info('valid-acc:{:.2%}'.format(correct.cpu().item() /
                                                      total.cpu().item()))
                logger.info('--------------------------')

    # free GPU memory to avoid running out of memory
    del basenet, seednet, fc, criterion, optimizer, scheduler
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(3)

    source_arch_dir = 'tmp_model/encoder.pth'
    return source_arch_dir
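
Note on the logging pattern above: accuracy is aggregated across ranks by cloning the local counters and reducing the clones, because dist.reduce works in place and would otherwise clobber the running per-rank counts. A minimal standalone sketch of that pattern (the helper name global_accuracy is hypothetical; it assumes the process group is already initialized, as in the main() functions below):

import torch
import torch.distributed as dist


def global_accuracy(correct, total, dst=0):
    # reduce clones so the per-rank running counters stay intact
    correct_sum = correct.clone()
    total_sum = total.clone()
    dist.reduce(correct_sum, dst=dst, op=dist.ReduceOp.SUM)
    dist.reduce(total_sum, dst=dst, op=dist.ReduceOp.SUM)
    if dist.get_rank() == dst:
        # only the destination rank is guaranteed to hold the global sums
        return (correct_sum / total_sum).item()
    return None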
Example 2
def search_step_ga(args, parents, seednet, logger):
    # data
    target_train_dir, target_valid_dir = default_dir[args.target_dataset]
    target_train_trans, target_valid_trans = hard_trans(args.target_img_size)

    target_trainset = MyDataset(target_train_dir, transform=target_train_trans)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        target_trainset)
    target_trainloader = DataLoader(target_trainset,
                                    batch_size=args.target_bs,
                                    shuffle=False,
                                    num_workers=8,
                                    pin_memory=True,
                                    sampler=train_sampler)

    target_validset = MyDataset(target_valid_dir, transform=target_valid_trans)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        target_validset)
    target_validloader = DataLoader(target_validset,
                                    batch_size=args.target_bs,
                                    shuffle=False,
                                    num_workers=8,
                                    pin_memory=True,
                                    sampler=valid_sampler)

    parents_acc = []
    for i_ga, v_ga in enumerate(parents):
        if args.local_rank == 0:
            logger.info('current progress:' + str(i_ga + 1) + '/' +
                        str(len(parents)) + '...')
            logger.info(v_ga)

        # model
        arch = resnet_nas(v_ga)
        arch = copy.deepcopy(remap_res2arch(seednet, arch))
        fc = models.__dict__['cos'](target_trainloader.dataset.n_classes,
                                    arch.emb_size, 100.0)
        criterion = models.__dict__['cross_entropy'](1.0)
        arch = arch.cuda()
        fc = fc.cuda()
        criterion = criterion.cuda()

        # optimizer
        optimizer = optimizers.__dict__['sgd']((arch, fc), args.optim_lr)
        scheduler = optimizers.__dict__['warm_cos'](optimizer,
                                                    args.target_warmup_epoch,
                                                    args.target_max_epoch,
                                                    len(target_trainloader))

        # ddp
        arch = torch.nn.SyncBatchNorm.convert_sync_batchnorm(arch)
        arch = ddp(arch,
                   device_ids=[args.local_rank],
                   output_device=args.local_rank,
                   find_unused_parameters=True)
        fc = ddp(fc,
                 device_ids=[args.local_rank],
                 output_device=args.local_rank,
                 find_unused_parameters=True)

        # train
        for i_epoch in range(args.target_max_epoch):
            arch.train()
            fc.train()
            target_trainloader.sampler.set_epoch(i_epoch)
            correct = torch.tensor(0.0).cuda()
            total = torch.tensor(0.0).cuda()
            start_time = time.time()

            for i_iter, data in enumerate(target_trainloader):
                img, label = data[:2]
                img, label = img.cuda(), label.cuda()

                optimizer.zero_grad()
                f = arch(img)
                s = fc(f)
                loss = criterion(s, label)
                loss.backward()
                optimizer.step()
                scheduler.step()

                # acc
                _, predicted = torch.max(s.data, 1)
                correct += predicted.eq(label.data).sum()
                total += label.size(0)

                # print info
                log_fre = len(
                    target_trainloader
                ) if args.target_log_fre == -1 else args.target_log_fre
                if (i_iter + 1) % log_fre == 0:
                    correct_tmp = correct.clone()
                    total_tmp = total.clone()
                    dist.reduce(correct_tmp, dst=0, op=dist.ReduceOp.SUM)
                    dist.reduce(total_tmp, dst=0, op=dist.ReduceOp.SUM)
                    if args.local_rank == 0:
                        eta = (time.time() - start_time) / 60.
                        logger.info(
                            "Training: Epoch[{:0>3}/{:0>3}] "
                            "Iter[{:0>3}/{:0>3}] "
                            "lr: {:.5f} "
                            "Loss: {:.4f} "
                            "Acc:{:.2%} "
                            "Run-T:{:.2f}m".format(
                                i_epoch + 1, args.target_max_epoch, i_iter + 1,
                                len(target_trainloader),
                                optimizer.state_dict()['param_groups'][0]
                                ['lr'],
                                loss.cpu().item() / dist.get_world_size(),
                                correct_tmp.cpu().item() /
                                total_tmp.cpu().item(), eta))

        # valid
        with torch.no_grad():
            arch.eval()
            fc.eval()
            correct = torch.tensor([0.0]).cuda()
            total = torch.tensor([0.0]).cuda()

            for data in target_validloader:
                img, label = data[:2]
                img, label = img.cuda(), label.cuda()

                feature = arch(img)
                s = fc(feature)

                # acc
                _, predicted = torch.max(s.data, 1)
                correct += predicted.eq(label.data).sum()
                total += label.size(0)

            # Pitfall: dist.reduce was used here before; it left each rank with
            # a different value and the run would hang partway through.
            dist.all_reduce(correct, op=dist.ReduceOp.SUM)
            dist.all_reduce(total, op=dist.ReduceOp.SUM)
            if args.local_rank == 0:
                logger.info('valid-acc:{:.2%}'.format(correct.cpu().item() /
                                                      total.cpu().item()))
                logger.info('--------------------------')
            parents_acc.append(correct / total)

        # free GPU memory to avoid running out of memory
        del arch, fc, criterion, optimizer, scheduler
        del data, img, label, correct, total
        gc.collect()
        torch.cuda.empty_cache()
        time.sleep(3)
    return parents_acc
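
The pitfall flagged in the comment above is worth spelling out: parents_acc drives the GA selection on every rank, so every rank has to end up with identical accuracy values. dist.reduce only delivers the summed result to the destination rank, while dist.all_reduce delivers it to all ranks. A minimal sketch of the difference, assuming an initialized NCCL process group (values are illustrative):

import torch
import torch.distributed as dist

x = torch.tensor([1.0]).cuda()
dist.all_reduce(x, op=dist.ReduceOp.SUM)
# x now equals the world size on every rank, so rank-dependent decisions
# (such as picking GA survivors) stay consistent across processes.

y = torch.tensor([1.0]).cuda()
dist.reduce(y, dst=0, op=dist.ReduceOp.SUM)
# only rank 0 is guaranteed to hold the sum; other ranks may keep their
# local value, and control flow that depends on y can diverge and deadlock.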
Example 3
def main(args):
    # dist init
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    # dataloader
    train_dir, valid_dir = default_dir[args.dataset]
    train_trans, valid_trans = hard_trans(args.img_size)

    trainset = MyDataset(train_dir, transform=train_trans)
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    trainloader = DataLoader(trainset,
                             batch_size=args.bs,
                             shuffle=False,
                             num_workers=8,
                             pin_memory=True,
                             sampler=train_sampler)

    validset = MyDataset(valid_dir, transform=valid_trans)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(validset)
    validloader = DataLoader(validset,
                             batch_size=args.bs,
                             shuffle=False,
                             num_workers=8,
                             pin_memory=True,
                             sampler=valid_sampler)

    # model
    arch = resnet_nas(args.arch_choose)
    if args.arch_dir is not None:
        arch.load_state_dict(load_normal(args.arch_dir))
        print('load success!')
    fc = models.__dict__['cos'](trainloader.dataset.n_classes, arch.emb_size,
                                100.0)
    criterion = models.__dict__['cross_entropy'](1.0)
    arch = arch.cuda()
    fc = fc.cuda()
    criterion = criterion.cuda()

    # optimizer
    optimizer = optimizers.__dict__['sgd']((arch, fc), args.optim_lr)
    scheduler = optimizers.__dict__['warm_cos'](optimizer, args.warmup_epoch,
                                                args.max_epoch,
                                                len(trainloader))

    # ddp
    arch = torch.nn.SyncBatchNorm.convert_sync_batchnorm(arch)
    arch = ddp(arch, device_ids=[args.local_rank], find_unused_parameters=True)
    fc = ddp(fc, device_ids=[args.local_rank], find_unused_parameters=True)

    # log
    if args.local_rank == 0:
        time_str = datetime.strftime(datetime.now(), '%y-%m-%d-%H-%M-%S')
        args.log_dir = os.path.join('logs', args.dataset + '_' + time_str)
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        logger = logger_init(args.log_dir)
    else:
        logger = None

    # train and valid
    run(args, arch, fc, criterion, optimizer, scheduler, trainloader,
        validloader, logger)
Example 4
def main(args):
    # dist init
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    # dataset
    rgb_mean = (104, 117, 123)  # bgr order
    dataset = MyDataset(args.txt_path, args.txt_path2,
                        preproc(args.img_size, rgb_mean))
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = DataLoader(dataset,
                            args.bs,
                            shuffle=False,
                            num_workers=args.num_workers,
                            collate_fn=detection_collate,
                            pin_memory=True,
                            sampler=sampler)

    # net and load
    net = RetinaFace(cfg=cfg_mnet)
    if args.resume_net is not None:
        print('Loading resume network...')
        state_dict = load_normal(args.resume_net)
        net.load_state_dict(state_dict)
        print('Loading success!')
    net = net.cuda()
    # ddp
    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = ddp(net, device_ids=[args.local_rank], find_unused_parameters=True)

    # optimizer and loss
    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    scheduler = WarmupCosineSchedule(optimizer,
                                     args.warm_epoch, args.max_epoch,
                                     len(dataloader), args.cycles)

    num_classes = 2
    criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 7, 0.35, False)

    # priorbox
    priorbox = PriorBox(cfg_mnet, image_size=(args.img_size, args.img_size))
    with torch.no_grad():
        priors = priorbox.forward()
        priors = priors.cuda()

    # save folder
    if args.local_rank == 0:
        time_str = datetime.datetime.strftime(datetime.datetime.now(),
                                              '%y-%m-%d-%H-%M-%S')
        args.save_folder = os.path.join(args.save_folder, time_str)
        if not os.path.exists(args.save_folder):
            os.makedirs(args.save_folder)
        logger = logger_init(args.save_folder)
        logger.info(args)
    else:
        logger = None

    # train
    for i_epoch in range(args.max_epoch):
        net.train()
        dataloader.sampler.set_epoch(i_epoch)
        for i_iter, data in enumerate(dataloader):
            load_t0 = time.time()
            images, targets = data[:2]
            images = images.cuda()
            targets = [anno.cuda() for anno in targets]

            # forward
            out = net(images)

            # backward
            optimizer.zero_grad()
            loss_l, loss_c, loss_landm = criterion(out, priors, targets)
            loss = cfg_mnet['loc_weight'] * loss_l + loss_c + loss_landm
            loss.backward()
            optimizer.step()
            scheduler.step()

            # print info
            load_t1 = time.time()
            batch_time = load_t1 - load_t0
            eta = int(batch_time * (len(dataloader) *
                                    (args.max_epoch - i_epoch) - i_iter))
            if args.local_rank == 0:
                logger.info(
                    'Epoch:{}/{} || Iter: {}/{} || '
                    'Loc: {:.4f} Cla: {:.4f} Landm: {:.4f} || '
                    'LR: {:.8f} || '
                    'Batchtime: {:.4f} s || '
                    'ETA: {}'.format(
                        i_epoch + 1, args.max_epoch, i_iter + 1,
                        len(dataloader), loss_l.item(), loss_c.item(),
                        loss_landm.item(),
                        optimizer.state_dict()['param_groups'][0]['lr'],
                        batch_time, str(datetime.timedelta(seconds=eta))))
        if (i_epoch + 1) % args.save_fre == 0:
            if args.local_rank == 0:
                save_name = 'mobile0.25_' + str(i_epoch + 1) + '.pth'
                torch.save(net.state_dict(),
                           os.path.join(args.save_folder, save_name))
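
The ETA printed above is simply the latest batch time multiplied by the number of batches left in the whole run, then formatted with datetime.timedelta. A quick worked example with made-up numbers:

import datetime

batch_time = 0.5          # seconds for the last batch
iters_per_epoch = 200     # len(dataloader)
epochs_left = 10          # args.max_epoch - i_epoch
i_iter = 50
eta = int(batch_time * (iters_per_epoch * epochs_left - i_iter))
print(str(datetime.timedelta(seconds=eta)))  # 0:16:15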
Example 5
def main(args):
    # dist init
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    # dataloader
    train_dir, valid_dir = default_dir[args.dataset]
    train_trans, valid_trans = hard_trans(args.img_size)

    trainset = MyDataset(train_dir, transform=train_trans)
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    trainloader = DataLoader(trainset,
                             batch_size=args.bs,
                             shuffle=False,
                             num_workers=8,
                             pin_memory=True,
                             sampler=train_sampler)

    validset = MyDataset(valid_dir, transform=valid_trans)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(validset)
    validloader = DataLoader(validset,
                             batch_size=args.bs,
                             shuffle=False,
                             num_workers=8,
                             pin_memory=True,
                             sampler=valid_sampler)

    # model
    supernet_choose = choose_rand(([6], [8], [12], [6]), (1.5, ), [(1, 7)])
    supernet = resnet_nas(supernet_choose)
    fc = models.__dict__['cos'](trainloader.dataset.n_classes,
                                supernet.emb_size, 100.0)
    criterion = models.__dict__['cross_entropy'](1.0)
    supernet = supernet.cuda()
    fc = fc.cuda()
    criterion = criterion.cuda()

    # optimizer
    optimizer = optimizers.__dict__['sgd']((supernet, fc), args.optim_lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     args.optim_step,
                                                     gamma=0.1)

    # ddp
    supernet = torch.nn.SyncBatchNorm.convert_sync_batchnorm(supernet)
    supernet = ddp(supernet,
                   device_ids=[args.local_rank],
                   find_unused_parameters=True)
    fc = ddp(fc, device_ids=[args.local_rank], find_unused_parameters=True)

    # log
    if args.local_rank == 0:
        time_str = datetime.strftime(datetime.now(), '%y-%m-%d-%H-%M-%S')
        args.log_dir = os.path.join('logs', args.dataset + '_' + time_str)
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        logger = logger_init(args.log_dir)
        logger.info(args)
        logger.info('n_classes:%d' % trainloader.dataset.n_classes)
    else:
        logger = None

    # train and valid
    run(args, supernet, fc, criterion, optimizer, scheduler, trainloader,
        validloader, logger)
    if args.local_rank == 0:
        logger.info(args.log_dir)
Example 6
def main(args):
    # dist init
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    print(torch.cuda.device_count(), args.local_rank)

    # data
    train_transform = tv.transforms.Compose([])
    if args.data_augmentation:
        train_transform.transforms.append(
            tv.transforms.RandomCrop(32, padding=4))
        train_transform.transforms.append(tv.transforms.RandomHorizontalFlip())
    train_transform.transforms.append(tv.transforms.ToTensor())
    normalize = tv.transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
    train_transform.transforms.append(normalize)

    test_transform = tv.transforms.Compose(
        [tv.transforms.ToTensor(), normalize])

    train_dataset = tv.datasets.CIFAR10(root='data/',
                                        train=True,
                                        transform=train_transform,
                                        download=True)

    test_dataset = tv.datasets.CIFAR10(root='data/',
                                       train=False,
                                       transform=test_transform,
                                       download=True)
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.bs,
                                               shuffle=False,
                                               pin_memory=True,
                                               num_workers=4,
                                               sampler=sampler)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.bs,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=4)

    # net
    net = tv.models.resnet18(num_classes=10)
    net = net.cuda()

    # optimizer and loss
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=0.1,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [50, 80], 0.1)
    criterion = torch.nn.CrossEntropyLoss().cuda()

    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = ddp(net, device_ids=[args.local_rank], find_unused_parameters=True)

    # train
    for i_epoch in range(100):
        net.train()
        time_s = time.time()
        train_loader.sampler.set_epoch(i_epoch)
        for i_iter, data in enumerate(train_loader):
            img, label = data
            img, label = img.cuda(), label.cuda()

            optimizer.zero_grad()
            feat = net(img)
            loss = criterion(feat, label)
            loss.backward()
            optimizer.step()
            time_e = time.time()

            if args.local_rank == 0:
                print('Epoch:{:3}/100 || Iter: {:4}/{} || '
                      'Loss: {:2.4f} '
                      'ETA: {:.2f}min'.format(
                          i_epoch + 1, i_iter + 1, len(train_loader),
                          loss.item(), (time_e - time_s) * (100 - i_epoch) *
                          len(train_loader) / (i_iter + 1) / 60))
        scheduler.step()
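
One detail shared by all of these training loops: train_loader.sampler.set_epoch(i_epoch) is called before each epoch. DistributedSampler derives its shuffle order from a fixed seed plus the epoch number, so without that call each rank replays the same ordering every epoch. A minimal sketch that shows the effect; num_replicas and rank are passed explicitly only so it runs without init_process_group:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)

for epoch in range(3):
    sampler.set_epoch(epoch)  # drop this line and every epoch prints the same order
    print(list(sampler))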
Example 7
def run(gpu_id, args):

    ## parameters for multi-processing
    print('using gpu', gpu_id)
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=args.world_size,
                            rank=gpu_id)

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    # device = load.device(args.gpu)
    device = torch.device(gpu_id)

    args.gpu_id = gpu_id

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)

    ## need to change the workers for loading the data
    args.workers = int((args.workers + 4 - 1) / 4)
    print('Adjusted dataloader worker number is ', args.workers)

    # prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True, args.workers,
    #                 args.prune_dataset_ratio * num_classes, world_size=args.world_size, rank=gpu_id)
    prune_loader, _ = load.dataloader(args.dataset, args.prune_batch_size,
                                      True, args.workers,
                                      args.prune_dataset_ratio * num_classes)

    ## need to divide the training batch size for each GPU
    args.train_batch_size = int(args.train_batch_size / args.gpu_count)
    train_loader, train_sampler = load.dataloader(args.dataset,
                                                  args.train_batch_size,
                                                  True,
                                                  args.workers,
                                                  args=args)
    # args.test_batch_size = int(args.test_batch_size/args.gpu_count)
    test_loader, _ = load.dataloader(args.dataset, args.test_batch_size, False,
                                     args.workers)

    print("data loader batch size (prune::train::test) is {}::{}::{}".format(
        prune_loader.batch_size, train_loader.batch_size,
        test_loader.batch_size))

    log_filename = '{}/{}'.format(args.result_dir, 'result.log')
    fout = open(log_filename, 'w')
    fout.write('start!\n')

    if args.compression_list == []:
        args.compression_list.append(args.compression)
    if args.pruner_list == []:
        args.pruner_list.append(args.pruner)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))

    # load the pre-defined model from the utils
    model = load.model(args.model, args.model_class)(input_shape, num_classes,
                                                     args.dense_classifier,
                                                     args.pretrained)

    ## wrap model with distributed dataparallel module
    torch.cuda.set_device(gpu_id)
    # model = model.to(device)
    model.cuda(gpu_id)
    model = ddp(model, device_ids=[gpu_id])

    ## don't need to move the loss to the GPU as it contains no parameters
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model,
                                 loss,
                                 optimizer,
                                 scheduler,
                                 train_loader,
                                 test_loader,
                                 device,
                                 args.pre_epochs,
                                 args.verbose,
                                 train_sampler=train_sampler)
    print('Pre-Train finished!')

    ## Save Original ##
    torch.save(model.state_dict(),
               "{}/pre_train_model_{}.pt".format(args.result_dir, gpu_id))
    torch.save(optimizer.state_dict(),
               "{}/pre_train_optimizer_{}.pt".format(args.result_dir, gpu_id))
    torch.save(scheduler.state_dict(),
               "{}/pre_train_scheduler_{}.pt".format(args.result_dir, gpu_id))

    if not args.unpruned:
        for compression in args.compression_list:
            for p in args.pruner_list:
                # Reset Model, Optimizer, and Scheduler
                print('compression ratio: {} ::: pruner: {}'.format(
                    compression, p))
                model.load_state_dict(
                    torch.load("{}/pre_train_model_{}.pt".format(
                        args.result_dir, gpu_id),
                               map_location=device))
                optimizer.load_state_dict(
                    torch.load("{}/pre_train_optimizer_{}.pt".format(
                        args.result_dir, gpu_id),
                               map_location=device))
                scheduler.load_state_dict(
                    torch.load("{}/pre_train_scheduler_{}.pt".format(
                        args.result_dir, gpu_id),
                               map_location=device))

                ## Prune ##
                print('Pruning with {} for {} epochs.'.format(
                    p, args.prune_epochs))
                pruner = load.pruner(p)(generator.masked_parameters(
                    model, args.prune_bias, args.prune_batchnorm,
                    args.prune_residual))
                sparsity = 10**(-float(compression))
                prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                           args.compression_schedule, args.mask_scope,
                           args.prune_epochs, args.reinitialize,
                           args.prune_train_mode, args.shuffle, args.invert)

                ## Post-Train ##
                print('Post-Training for {} epochs.'.format(args.post_epochs))
                post_train_start_time = timeit.default_timer()
                post_result = train_eval_loop(model,
                                              loss,
                                              optimizer,
                                              scheduler,
                                              train_loader,
                                              test_loader,
                                              device,
                                              args.post_epochs,
                                              args.verbose,
                                              train_sampler=train_sampler)
                post_train_end_time = timeit.default_timer()
                print("Post Training time: {:.4f}s".format(
                    post_train_end_time - post_train_start_time))

                ## Display Results ##
                frames = [
                    pre_result.head(1),
                    pre_result.tail(1),
                    post_result.head(1),
                    post_result.tail(1)
                ]
                train_result = pd.concat(
                    frames, keys=['Init.', 'Pre-Prune', 'Post-Prune', 'Final'])
                prune_result = metrics.summary(
                    model, pruner.scores,
                    metrics.flop(model, input_shape, device),
                    lambda p: generator.prunable(p, args.prune_batchnorm,
                                                 args.prune_residual))
                total_params = int(
                    (prune_result['sparsity'] * prune_result['size']).sum())
                possible_params = prune_result['size'].sum()
                total_flops = int(
                    (prune_result['sparsity'] * prune_result['flops']).sum())
                possible_flops = prune_result['flops'].sum()
                print("Train results:\n", train_result)
                # print("Prune results:\n", prune_result)
                # print("Parameter Sparsity: {}/{} ({:.4f})".format(total_params, possible_params, total_params / possible_params))
                # print("FLOP Sparsity: {}/{} ({:.4f})".format(total_flops, possible_flops, total_flops / possible_flops))

                ## recording testing time for task 2 ##
                # evaluating the model, including some data gathering overhead
                # eval(model, loss, test_loader, device, args.verbose)
                model.eval()
                start_time = timeit.default_timer()
                with torch.no_grad():
                    for data, target in test_loader:
                        data, target = data.to(device), target.to(device)
                        temp_eval_out = model(data)
                end_time = timeit.default_timer()
                print("Testing time: {:.4f}s".format(end_time - start_time))

                fout.write('compression ratio: {} ::: pruner: {}\n'.format(
                    compression, p))
                fout.write('Train results:\n {}\n'.format(train_result))
                fout.write('Prune results:\n {}\n'.format(prune_result))
                fout.write('Parameter Sparsity: {}/{} ({:.4f})\n'.format(
                    total_params, possible_params,
                    total_params / possible_params))
                fout.write("FLOP Sparsity: {}/{} ({:.4f})\n".format(
                    total_flops, possible_flops, total_flops / possible_flops))
                fout.write("Testing time: {}s\n".format(end_time - start_time))
                fout.write("remaining weights: \n{}\n".format(
                    (prune_result['sparsity'] * prune_result['size'])))
                fout.write('flop each layer: {}\n'.format(
                    (prune_result['sparsity'] *
                     prune_result['flops']).values.tolist()))
                ## Save Results and Model ##
                if args.save:
                    print('Saving results.')
                    if not os.path.exists('{}/{}'.format(
                            args.result_dir, compression)):
                        os.makedirs('{}/{}'.format(args.result_dir,
                                                   compression))
                    # pre_result.to_pickle("{}/{}/pre-train.pkl".format(args.result_dir, compression))
                    # post_result.to_pickle("{}/{}/post-train.pkl".format(args.result_dir, compression))
                    # prune_result.to_pickle("{}/{}/compression.pkl".format(args.result_dir, compression))
                    # torch.save(model.state_dict(), "{}/{}/model.pt".format(args.result_dir, compression))
                    # torch.save(optimizer.state_dict(),
                    #         "{}/{}/optimizer.pt".format(args.result_dir, compression))
                    # torch.save(scheduler.state_dict(),
                    #         "{}/{}/scheduler.pt".format(args.result_dir, compression))

    else:
        print('Starting Unpruned NN training')
        print('Training for {} epochs.'.format(args.post_epochs))
        model.load_state_dict(
            torch.load("{}/pre_train_model.pt".format(args.result_dir),
                       map_location=device))
        optimizer.load_state_dict(
            torch.load("{}/pre_train_optimizer.pt".format(args.result_dir),
                       map_location=device))
        scheduler.load_state_dict(
            torch.load("{}/pre_train_scheduler.pt".format(args.result_dir),
                       map_location=device))

        train_start_time = timeit.default_timer()
        result = train_eval_loop(model,
                                 loss,
                                 optimizer,
                                 scheduler,
                                 train_loader,
                                 test_loader,
                                 device,
                                 args.post_epochs,
                                 args.verbose,
                                 train_sampler=train_sampler)
        train_end_time = timeit.default_timer()
        frames = [result.head(1), result.tail(1)]
        train_result = pd.concat(frames, keys=['Init.', 'Final'])
        print('Train results:\n', train_result)

    fout.close()

    dist.destroy_process_group()
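
For reference, the compression-to-sparsity mapping used in the pruning loop above is sparsity = 10 ** (-compression), i.e. the fraction of weights kept after pruning (the same quantity the log multiplies by layer size to report remaining weights). A quick worked example with illustrative compression values:

for compression in [0.5, 1.0, 2.0]:
    sparsity = 10 ** (-float(compression))
    print(compression, sparsity)  # 0.5 -> ~0.316, 1.0 -> 0.1, 2.0 -> 0.01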