Example #1
def test_single_input(self):
    model = SingleInputNet()
    input = (1, 28, 28)
    result, (total_params, trainable_params) = summary_string(
        model, input, device="cpu")
    self.assertEqual(type(result), str)
    self.assertEqual(total_params, 21840)
    self.assertEqual(trainable_params, 21840)
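The test above comes from torchsummary's own suite and shows the shape of summary_string's return value: a formatted table plus a (total, trainable) parameter tuple. Below is a minimal standalone sketch of the same call; the Sequential model is a hypothetical stand-in for SingleInputNet, which is not shown here.

import torch.nn as nn
import torchsummary

# Hypothetical stand-in model; any small nn.Module works.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

# summary_string returns the formatted table and the parameter counts.
summary_text, (total_params, trainable_params) = torchsummary.summary_string(
    model, (1, 28, 28), device="cpu")
print(summary_text)
print("total:", total_params, "trainable:", trainable_params)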
Example #2
def print_model_summary(engine: Engine, model: nn.Module) -> None:
    try:
        import torchsummary
        device = engine.state.device
        s, _ = torchsummary.summary_string(model, (3, 96, 96), device=device)
        print(s)
    except Exception:
        print(model)
        n = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('Number of parameters: {}'.format(n))
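print_model_summary above is written as an Ignite event handler: it receives the Engine first and the model as an extra argument. Below is a minimal registration sketch; the step function and model are hypothetical placeholders, and engine.state.device is populated elsewhere in that codebase rather than by Ignite itself (if it is missing, the handler simply falls back to printing the model).

import torch.nn as nn
from ignite.engine import Engine, Events

def step(engine, batch):
    # hypothetical no-op training step
    return 0.0

trainer = Engine(step)
model = nn.Sequential(nn.Conv2d(3, 8, 3))  # hypothetical model

# Ignite passes the engine to the handler and forwards `model` as the
# extra positional argument when STARTED fires.
trainer.add_event_handler(Events.STARTED, print_model_summary, model)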
Example #3
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu
    num_classes = 10

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.node_rank == -1:
            args.node_rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.node_rank = args.node_rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.node_rank)

    # create model
    if args.model:
        if args.arch or args.pretrained:
            warnings.warn(
                "Ignoring arguments \"arch\" and \"pretrained\" when creating model..."
            )
        model = None
        saved_checkpoint = torch.load(args.model)
        if isinstance(saved_checkpoint, nn.Module):
            model = saved_checkpoint
        elif "model" in saved_checkpoint:
            model = saved_checkpoint["model"]
        else:
            raise Exception("Unable to load model from " + args.model)

        if (args.gpu is not None):
            model.cuda(args.gpu)
    elif args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.weights:
        saved_weights = torch.load(args.weights)
        if isinstance(saved_weights, nn.Module):
            state_dict = saved_weights.state_dict()
        elif "state_dict" in saved_weights:
            state_dict = saved_weights["state_dict"]
        else:
            state_dict = saved_weights

        try:
            model.load_state_dict(state_dict)
        except Exception:
            # build a new OrderedDict without the "module." prefix
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k.replace("module.", "")
                new_state_dict[name] = v

            model.load_state_dict(new_state_dict)

    print("Original Model:")
    print(model)
    print("\n\n")

    # create decomposition configuration
    decomp_config = {
        "criterion": None,
        "threshold": args.threshold,
        "rank": args.rank,
        "exclude_first_conv": args.exclude_first_conv,
        "exclude_linears": args.exclude_linears,
        "conv_ranks": args.conv_ranks,
        "mask_conv_layers": None
    }

    if args.decompose:
        print("Decomposing...")

        model = decompose_model(model, args.decompose_type, decomp_config)
        print("\n\n")

        print("Decomposed Model:")
        print(model)
        print("\n\n")

    if args.reconstruct:
        print("Reconstructing...")
        model = reconstruct_model(model, args.decompose_type)
        print("\n\n")

        print("Reconstructed Model:")
        print(model)
        print("\n\n")

    # print summary of model before parallelizing among different GPUs
    model_summary = None
    try:
        model_summary, model_params_info = torchsummary.summary_string(
            model, input_size=(3, 32, 32))
        print(model_summary)
    except Exception as e:
        warnings.warn("Unable to obtain summary of model")
        print(e)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if (args.arch.startswith('alexnet')) and args.pretrained != "cifar10":
            if (hasattr(model, 'features')):
                model.features = torch.nn.DataParallel(model.features)
                model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # define optimizer
    optimizer = None
    if (args.optimizer.lower() == "sgd"):
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "adadelta"):
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         args.lr,
                                         weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "adagrad"):
        optimizer = torch.optim.Adagrad(model.parameters(),
                                        args.lr,
                                        weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "adam"):
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "rmsprop"):
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        args.lr,
                                        weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "radam"):
        optimizer = optim.RAdam(model.parameters(),
                                args.lr,
                                weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "ranger"):
        optimizer = optim.Ranger(model.parameters(),
                                 args.lr,
                                 weight_decay=args.weight_decay)
    else:
        raise ValueError(
            "Optimizer type: {} is not supported or known".format(args.optimizer))

    lr_scheduler = None
    if args.opt_ckpt:
        warnings.warn(
            "Ignoring arguments \"lr\", \"momentum\", \"weight_decay\", and \"lr_schedule\""
        )

        opt_ckpt = torch.load(args.opt_ckpt)
        # read the lr_scheduler before opt_ckpt is narrowed to the optimizer state
        if 'lr_scheduler' in opt_ckpt:
            lr_scheduler = opt_ckpt['lr_scheduler']
        if 'optimizer' in opt_ckpt:
            opt_ckpt = opt_ckpt['optimizer']
        optimizer.load_state_dict(opt_ckpt)

    # define learning rate schedule
    if (args.lr_schedule == 'MultiStepLR'):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[80, 120, 160, 180], gamma=0.1)
    elif (args.lr_schedule == 'StepLR'):
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                       args.lr_step_size,
                                                       gamma=0.1)
    else:
        lr_scheduler = None

    if args.arch in ['resnet1202', 'resnet110']:
        # For resnet1202 the original paper uses lr=0.01 for the first 400 minibatches
        # as warm-up, then switches back; here that corresponds to the first epoch.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr * 0.1

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if 'lr_scheduler' in checkpoint and checkpoint[
                    'lr_scheduler'] is not None:
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # optionally reset weights
    def reset_weights(m):
        reset_parameters = getattr(m, "reset_parameters", None)
        if callable(reset_parameters):
            m.reset_parameters()

    if args.reset_weights:
        model.apply(reset_weights)

    cudnn.benchmark = True

    # name model directory
    if (args.decompose):
        decompose_label = args.decompose_type + "_decompose"
    elif (args.reconstruct):
        decompose_label = "reconstruct"
    else:
        decompose_label = "no_decompose"

    arch_name = "generic" if (args.arch is None
                              or len(args.arch) == 0) else args.arch
    if args.desc is not None and len(args.desc) > 0:
        model_name = '%s/%s_%s' % (arch_name, args.desc, decompose_label)
    else:
        model_name = '%s/%s' % (arch_name, decompose_label)

    if (args.save_model):
        model_dir = os.path.join(
            os.path.join(os.path.join(os.getcwd(), "models"), "cifar10"),
            model_name)
        if not os.path.isdir(model_dir):
            os.makedirs(model_dir, exist_ok=True)

        with open(os.path.join(model_dir, 'command_args.txt'),
                  'w') as command_args_file:
            for arg, value in sorted(vars(args).items()):
                command_args_file.write(arg + ": " + str(value) + "\n")

        with open(os.path.join(model_dir, 'model.txt'), 'w') as model_txt_file:
            with redirect_stdout(model_txt_file):
                print(model)

        with open(os.path.join(model_dir, 'model_summary.txt'),
                  'w') as summary_file:
            with redirect_stdout(summary_file):
                if (model_summary is not None):
                    print(model_summary)
                else:
                    warnings.warn("Unable to obtain summary of model")

    # Data loading code
    data_dir = "~/pytorch_datasets"
    os.makedirs(model_dir, exist_ok=True)

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010])

    train_dataset = datasets.CIFAR10(root=data_dir,
                                     train=True,
                                     transform=transforms.Compose([
                                         transforms.RandomHorizontalFlip(),
                                         transforms.RandomCrop(size=32,
                                                               padding=4),
                                         transforms.ToTensor(),
                                         normalize,
                                     ]),
                                     download=True)

    train_dataset_downsized = datasets.CIFAR10(
        root=data_dir,
        train=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32, padding=4),
            transforms.Resize((16, 16)),
            transforms.ToTensor(),
            normalize,
        ]),
        download=True)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    if args.downsize_freq is not None:
        train_loader_downsized = torch.utils.data.DataLoader(
            train_dataset_downsized,
            batch_size=args.batch_size *
            int(1 if args.downsize_bm is None else args.downsize_bm),
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        root=data_dir,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    start_time = time.time()

    if args.evaluate:
        start_log_time = time.time()
        val_log = validate(val_loader, model, criterion, args)
        val_log = [val_log]

        with open(os.path.join(model_dir, "test_log.csv"),
                  "w") as test_log_file:
            test_log_csv = csv.writer(test_log_file)
            test_log_csv.writerow(
                ['test_loss', 'test_top1_acc', 'test_time', 'cumulative_time'])
            test_log_csv.writerows(val_log +
                                   [(time.time() - start_log_time, )])
    else:
        train_log = []

        with open(os.path.join(model_dir, "train_log.csv"),
                  "w") as train_log_file:
            train_log_csv = csv.writer(train_log_file)
            train_log_csv.writerow([
                'epoch', 'train_loss', 'train_top1_acc', 'train_time',
                'grad_first_layer', 'test_loss', 'test_top1_acc', 'test_time',
                'cumulative_time'
            ])

        # initialize lr scheduler according to start_epoch
        if (args.lr_schedule):
            for i in range(args.start_epoch):
                lr_scheduler.step()

        start_log_time = time.time()
        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                train_sampler.set_epoch(epoch)

            # train for one epoch
            if args.downsize_freq is not None and epoch % args.downsize_freq == 0:
                train_loader_chosen = train_loader_downsized
                for param_group in optimizer.param_groups:
                    param_group['lr'] /= args.downsize_lr_reduction
            else:
                train_loader_chosen = train_loader

            print('current lr {:.4e}'.format(optimizer.param_groups[0]['lr']))
            train_epoch_log = train(train_loader_chosen, model, criterion,
                                    optimizer, epoch, args)

            # update learning rate
            if args.downsize_freq is not None and epoch % args.downsize_freq == 0:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= args.downsize_lr_reduction

            if (args.lr_schedule):
                lr_scheduler.step()

            if args.arch in ['resnet1202', 'resnet110'] and epoch == 0:
                # For resnet1202 the original paper uses lr=0.01 for the first 400 minibatches
                # as warm-up, then switches back; here that corresponds to the first epoch.
                for param_group in optimizer.param_groups:
                    param_group['lr'] = args.lr

            # evaluate on validation set
            val_epoch_log = validate(val_loader, model, criterion, args)
            acc1 = val_epoch_log[1]

            # append to log
            with open(os.path.join(model_dir, "train_log.csv"),
                      "a") as train_log_file:
                train_log_csv = csv.writer(train_log_file)
                train_log_csv.writerow(
                    ((epoch, ) + tuple(train_epoch_log.values()) +
                     val_epoch_log + (time.time() - start_log_time, )))

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

            if (args.print_weights):
                os.makedirs(os.path.join(model_dir, 'weights_logs'),
                            exist_ok=True)
                with open(
                        os.path.join(model_dir, 'weights_logs',
                                     'weights_log_' + str(epoch) + '.txt'),
                        'w') as weights_log_file:
                    with redirect_stdout(weights_log_file):
                        # Log model's state_dict
                        print("Model's state_dict:")
                        # TODO: Use checkpoint above
                        for param_tensor in model.state_dict():
                            print(param_tensor, "\t",
                                  model.state_dict()[param_tensor].size())
                            print(model.state_dict()[param_tensor])
                            print("")

            if not args.multiprocessing_distributed or (
                    args.multiprocessing_distributed
                    and args.node_rank % ngpus_per_node == 0):
                if is_best:
                    try:
                        if (args.save_model):
                            torch.save(model,
                                       os.path.join(model_dir, "model.pth"))
                    except Exception:
                        warnings.warn("Unable to save model.pth")
                    try:
                        if (args.save_model):
                            torch.save(model.state_dict(),
                                       os.path.join(model_dir, "weights.pth"))
                    except Exception:
                        warnings.warn("Unable to save weights.pth")

                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_acc1': best_acc1,
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler,
                    }, is_best, model_dir)

    end_time = time.time()
    print("Total Time:", end_time - start_time)

    if (args.print_weights):
        with open(os.path.join(model_dir, 'weights_log.txt'),
                  'w') as weights_log_file:
            with redirect_stdout(weights_log_file):
                # Log model's state_dict
                print("Model's state_dict:")
                # TODO: Use checkpoint above
                for param_tensor in model.state_dict():
                    print(param_tensor, "\t",
                          model.state_dict()[param_tensor].size())
                    print(model.state_dict()[param_tensor])
                    print("")
Example #4
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu
    num_classes = 10

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    if args.model:
        if args.arch or args.pretrained:
            print(
                "WARNING: Ignoring arguments \"arch\" and \"pretrained\" when creating model..."
            )
        model = None
        saved_checkpoint = torch.load(args.model)
        if isinstance(saved_checkpoint, nn.Module):
            model = saved_checkpoint
        elif "model" in saved_checkpoint:
            model = saved_checkpoint["model"]
        else:
            raise Exception("Unable to load model from " + args.model)

        if (args.gpu is not None):
            model.cuda(args.gpu)
    elif args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model_rounded = None

    # TODO: add an option for fine-tuning vs. feature extraction; this only works
    # when the pretrained weights are ImageNet weights
    if args.freeze and args.pretrained != "none":
        for param in model.parameters():
            param.requires_grad = False

    if args.weights:
        saved_weights = torch.load(args.weights)
        if isinstance(saved_weights, nn.Module):
            state_dict = saved_weights.state_dict()
        elif "state_dict" in saved_weights:
            state_dict = saved_weights["state_dict"]
        else:
            state_dict = saved_weights

        try:
            model.load_state_dict(state_dict)
        except Exception:
            # build a new OrderedDict without the "module." prefix
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove module.
                new_state_dict[name] = v

            model.load_state_dict(new_state_dict)

    if args.shift_depth > 0:
        model, _ = convert_to_shift(model,
                                    args.shift_depth,
                                    args.shift_type,
                                    convert_weights=(args.pretrained != "none"
                                                     or args.weights),
                                    use_kernel=args.use_kernel,
                                    rounding=args.rounding,
                                    weight_bits=args.weight_bits)
    elif args.use_kernel and args.shift_depth == 0:
        model = convert_to_unoptimized(model)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            #TODO: Allow args.gpu to be a list of IDs
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if (args.arch.startswith('alexnet')):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # create optimizer
    model_other_params = []
    model_sign_params = []
    model_shift_params = []

    for name, param in model.named_parameters():
        if (name.endswith(".sign")):
            model_sign_params.append(param)
        elif (name.endswith(".shift")):
            model_shift_params.append(param)
        else:
            model_other_params.append(param)

    params_dict = [{
        "params": model_other_params
    }, {
        "params": model_sign_params,
        'lr': args.lr_sign if args.lr_sign is not None else args.lr,
        'weight_decay': 0
    }, {
        "params": model_shift_params,
        'lr': args.lr,
        'weight_decay': 0
    }]

    # define optimizer
    optimizer = None
    if (args.optimizer.lower() == "sgd"):
        optimizer = torch.optim.SGD(params_dict,
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "adadelta"):
        optimizer = torch.optim.Adadelta(params_dict,
                                         args.lr,
                                         weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "adagrad"):
        optimizer = torch.optim.Adagrad(params_dict,
                                        args.lr,
                                        weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "adam"):
        optimizer = torch.optim.Adam(params_dict,
                                     args.lr,
                                     weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "rmsprop"):
        optimizer = torch.optim.RMSprop(params_dict,
                                        args.lr,
                                        weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "radam"):
        optimizer = optim.RAdam(params_dict,
                                args.lr,
                                weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "ranger"):
        optimizer = optim.Ranger(params_dict,
                                 args.lr,
                                 weight_decay=args.weight_decay)
    else:
        raise ValueError(
            "Optimizer type: {} is not supported or known".format(args.optimizer))

    # define learning rate schedule
    if (args.lr_schedule):
        if (args.lr_step_size is not None):
            lr_scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=args.lr_step_size)
        else:
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[80, 120, 160, 180],
                last_epoch=args.start_epoch - 1)

    if args.arch in ['resnet1202', 'resnet110']:
        # For resnet1202 the original paper uses lr=0.01 for the first 400 minibatches
        # as warm-up, then switches back; here that corresponds to the first epoch.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr * 0.1

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # if evaluating, round the weights to ensure the results reflect power-of-2 weights
    if (args.evaluate):
        model = round_shift_weights(model)

    cudnn.benchmark = True

    model_summary = None
    try:
        model_summary, model_params_info = torchsummary.summary_string(
            model, input_size=(3, 32, 32))
        print(model_summary)
        print(
            "WARNING: The summary function reports duplicate parameters for multi-GPU case"
        )
    except Exception:
        print("WARNING: Unable to obtain summary of model")

    # name model sub-directory "shift_all" if all layers are converted to shift layers
    conv2d_layers_count = count_layer_type(
        model, nn.Conv2d) + count_layer_type(model,
                                             unoptimized.UnoptimizedConv2d)
    linear_layers_count = count_layer_type(
        model, nn.Linear) + count_layer_type(model,
                                             unoptimized.UnoptimizedLinear)
    if (args.shift_depth > 0):
        if (args.shift_type == 'Q'):
            shift_label = "shift_q"
        else:
            shift_label = "shift_ps"
    else:
        shift_label = "shift"

    if (conv2d_layers_count == 0 and linear_layers_count == 0):
        shift_label += "_all"
    else:
        shift_label += "_%s" % (args.shift_depth)

    if (args.shift_depth > 0):
        shift_label += "_wb_%s" % (args.weight_bits)

    if (args.desc is not None and len(args.desc) > 0):
        desc_label = "_%s" % (args.desc)
    else:
        desc_label = ""

    model_name = '%s/%s%s' % (args.arch, shift_label, desc_label)

    if (args.save_model):
        model_dir = os.path.join(
            os.path.join(os.path.join(os.getcwd(), "models"), "cifar10"),
            model_name)
        if not os.path.isdir(model_dir):
            os.makedirs(model_dir, exist_ok=True)

        with open(os.path.join(model_dir, 'command_args.txt'),
                  'w') as command_args_file:
            for arg, value in sorted(vars(args).items()):
                command_args_file.write(arg + ": " + str(value) + "\n")

        with open(os.path.join(model_dir, 'model_summary.txt'),
                  'w') as summary_file:
            with redirect_stdout(summary_file):
                if (model_summary is not None):
                    print(model_summary)
                    print(
                        "WARNING: The summary function reports duplicate parameters for multi-GPU case"
                    )
                else:
                    print("WARNING: Unable to obtain summary of model")

    # Data loading code
    data_dir = "~/pytorch_datasets"
    os.makedirs(model_dir, exist_ok=True)

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010])

    train_dataset = datasets.CIFAR10(root=data_dir,
                                     train=True,
                                     transform=transforms.Compose([
                                         transforms.RandomHorizontalFlip(),
                                         transforms.RandomCrop(size=32,
                                                               padding=4),
                                         transforms.ToTensor(),
                                         normalize,
                                     ]),
                                     download=True)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        root=data_dir,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    start_time = time.time()

    if args.evaluate:
        val_log = validate(val_loader, model, criterion, args)
        val_log = [val_log]

        with open(os.path.join(model_dir, "test_log.csv"),
                  "w") as test_log_file:
            test_log_csv = csv.writer(test_log_file)
            test_log_csv.writerow(['test_loss', 'test_top1_acc', 'test_time'])
            test_log_csv.writerows(val_log)
    else:
        train_log = []

        with open(os.path.join(model_dir, "train_log.csv"),
                  "w") as train_log_file:
            train_log_csv = csv.writer(train_log_file)
            train_log_csv.writerow([
                'epoch', 'train_loss', 'train_top1_acc', 'train_time',
                'test_loss', 'test_top1_acc', 'test_time'
            ])

        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                train_sampler.set_epoch(epoch)

            if (args.alternate_update):
                if epoch % 2 == 1:
                    optimizer.param_groups[1]['lr'] = 0
                    optimizer.param_groups[2]['lr'] = optimizer.param_groups[
                        0]['lr']
                else:
                    optimizer.param_groups[1]['lr'] = optimizer.param_groups[
                        0]['lr']
                    optimizer.param_groups[2]['lr'] = 0

            # train for one epoch
            print("current lr ",
                  [param['lr'] for param in optimizer.param_groups])
            train_epoch_log = train(train_loader, model, criterion, optimizer,
                                    epoch, args)
            if (args.lr_schedule):
                lr_scheduler.step()

            # evaluate on validation set
            val_epoch_log = validate(val_loader, model, criterion, args)
            acc1 = val_epoch_log[2]

            # append to log
            with open(os.path.join(model_dir, "train_log.csv"),
                      "a") as train_log_file:
                train_log_csv = csv.writer(train_log_file)
                train_log_csv.writerow(
                    ((epoch, ) + train_epoch_log + val_epoch_log))

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

            if (args.print_weights):
                with open(
                        os.path.join(model_dir,
                                     'weights_log_' + str(epoch) + '.txt'),
                        'w') as weights_log_file:
                    with redirect_stdout(weights_log_file):
                        # Log model's state_dict
                        print("Model's state_dict:")
                        # TODO: Use checkpoint above
                        for param_tensor in model.state_dict():
                            print(param_tensor, "\t",
                                  model.state_dict()[param_tensor].size())
                            print(model.state_dict()[param_tensor])
                            print("")

            if not args.multiprocessing_distributed or (
                    args.multiprocessing_distributed
                    and args.rank % ngpus_per_node == 0):
                if is_best:
                    try:
                        if (args.save_model):
                            model_rounded = round_shift_weights(model,
                                                                clone=True)

                            torch.save(model_rounded.state_dict(),
                                       os.path.join(model_dir, "weights.pth"))
                            torch.save(model_rounded,
                                       os.path.join(model_dir, "model.pth"))
                    except Exception:
                        print("WARNING: Unable to save model.pth")

                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_acc1': best_acc1,
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler,
                    }, is_best, model_dir)

    end_time = time.time()
    print("Total Time:", end_time - start_time)

    if (args.print_weights):
        if (model_rounded is None):
            model_rounded = round_shift_weights(model, clone=True)

        with open(os.path.join(model_dir, 'weights_log.txt'),
                  'w') as weights_log_file:
            with redirect_stdout(weights_log_file):
                # Log model's state_dict
                print("Model's state_dict:")
                # TODO: Use checkpoint above
                for param_tensor in model_rounded.state_dict():
                    print(param_tensor, "\t",
                          model_rounded.state_dict()[param_tensor].size())
                    print(model_rounded.state_dict()[param_tensor])
                    print("")
Example #5
def train(train_loader, model, criterion, optimizer, epoch, args):
    batch_time = AverageMeter('Time')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader), [batch_time],
                             args.rank,
                             prefix="Net: {} world: {} batch:{} on {}".format(
                                 args.arch, args.world_size, args.batch_size,
                                 args.tag))

    # switch to train mode
    model.train()

    end = time.time()

    model_parameters = [x for x in model.parameters() if x.requires_grad]
    parameters = [np.prod(p.size()) for p in model_parameters]
    params = sum(parameters)
    max_params = max(parameters) * 4
    min_params = min(parameters) * 4
    #print(model_parameters)

    print("copy below for layer sizes")
    #print([4*x for x in parameters])
    #print("detected model size: %d MB. average = %s B. max = %s B. min = %s Bcnt = %d\n" %
    #      (params * 4 / 1024 / 1024, params * 4 / len(model_parameters), max_params, min_params, len(model_parameters)))

    if 'inception' in args.arch or 'googlenet' in args.arch:
        image_size = (3, 299, 299)
        #overwrite transform
        pass
    else:
        image_size = (3, 224, 224)
        pass

    acc_forward = 0
    acc_backward = 0
    if args.so_no_backward:
        print("warning: backward pass is turned off. benchmark only")
        pass
    iteration = 0
    for i, (images, target) in enumerate(train_loader):
        #print("actual loaded batch = %d" % len(images))
        # measure data loading time
        # intercept the loop
        # if i == 0:
        if args.so_layer_info:
            _, params1, backward_ts = summary_string(model,
                                                     input_size=image_size,
                                                     bucketize=False)
            if sum(params1) != params * 4:
                print(sum(params1))
                print(params * 4)
                assert sum(params1) == params * 4
            print(params1, flush=True)
            print(backward_ts, flush=True)
            pass

        #print(target)
        for _ in range(100000000) if args.data is None else range(1):
            data_time.update(time.time() - end)
            #fws = time.time_ns()
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                pass
            target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            #fwe = time.time_ns()
            #acc_forward += fwe - fws
            if args.data is not None:
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                pass

            # compute gradient and do SGD step
            #bws = time.time_ns()
            if not args.so_no_backward:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pass
            #fwe = time.time_ns()
            #acc_forward += fwe - fws

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            #acc_backward += bwe - bws
            # print(i)
            if iteration % args.print_freq == 0 and iteration > 0:
                if args.rank == 0:
                    progress.display(iteration)
                    if args.data is not None:
                        print(top5)
                        print(top1)
                        pass
                    pass
                if args.so_one_shot:
                    return

                #print("[%.2f, %.2f]" % (acc_forward / args.print_freq / 1000000.0,
                #                        acc_backward / args.print_freq/1000000.0), flush=True)
                #acc_forward = 0
                #acc_backward = 0
                pass
            iteration += 1
            pass
        pass
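The ad-hoc size accounting at the top of train() (trainable parameter counts multiplied by 4 bytes for float32) can be factored into a small helper. A sketch under the same float32 assumption the original comments make; trainable_param_stats is a hypothetical name.

def trainable_param_stats(model):
    # Sizes of trainable parameters, assuming 4-byte float32 storage.
    sizes = [p.numel() for p in model.parameters() if p.requires_grad]
    return {
        "param_count": sum(sizes),
        "total_bytes": 4 * sum(sizes),
        "largest_tensor_bytes": 4 * max(sizes),
        "smallest_tensor_bytes": 4 * min(sizes),
    }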
Example #6
def model_summary(model_type,
                  img_res,
                  hidden_size,
                  enc_type,
                  dec_type,
                  loss,
                  batch_size,
                  device=torch.device("cuda:1"),
                  verbose=True):
    pattern = re.compile(r"Params size \(MB\):(.*)\n")
    pattern2 = re.compile(r"Forward/backward pass size \(MB\):(.*)\n")
    input_dim = 3
    enc_input_size = (input_dim, img_res, img_res)
    dec_input_size = (hidden_size, img_res // 4, img_res // 4)
    if verbose:
        print(f"model:{model_type}")
        print(f"depth:{enc_type}_{dec_type}")

    if model_type == "acai":
        model = ACAI(img_res, input_dim, hidden_size, enc_type,
                     dec_type).to(device)
    elif model_type == "vqvae":
        model = VectorQuantizedVAE(input_dim,
                                   hidden_size,
                                   enc_type=enc_type,
                                   dec_type=dec_type).to(device)
    elif model_type == "vae":
        model = VAE(input_dim,
                    hidden_size,
                    enc_type=enc_type,
                    dec_type=dec_type).to(device)

    encoder_summary, _ = torchsummary.summary_string(model.encoder,
                                                     enc_input_size,
                                                     device=device,
                                                     batch_size=batch_size)
    decoder_summary, _ = torchsummary.summary_string(model.decoder,
                                                     dec_input_size,
                                                     device=device,
                                                     batch_size=batch_size)
    if verbose:
        print(encoder_summary)
        print(decoder_summary)

    discriminators = {}

    if model_type == "acai":
        disc = Discriminator(input_dim, img_res, "image").to(device)

        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        disc_param_size = float(re.search(pattern, disc_summary).group(1))
        disc_forward_size = float(re.search(pattern2, disc_summary).group(1))
        discriminators["interp_disc"] = (disc_param_size, disc_forward_size)
    if loss == "gan":
        disc = Discriminator(input_dim, img_res, "image").to(device)

        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        disc_param_size = float(re.search(pattern, disc_summary).group(1))
        disc_forward_size = float(re.search(pattern2, disc_summary).group(1))
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)
    elif loss == "comp":
        disc = AnchorComparator(input_dim * 2, img_res, "image").to(device)

        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        disc_param_size = float(re.search(pattern, disc_summary).group(1))
        disc_forward_size = float(re.search(pattern2, disc_summary).group(1))
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)
    elif "comp_2" in loss:
        disc = ClubbedPermutationComparator(input_dim * 2, img_res,
                                            "image").to(device)

        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        disc_param_size = float(re.search(pattern, disc_summary).group(1))
        disc_forward_size = float(re.search(pattern2, disc_summary).group(1))
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)
    elif "comp_6" in loss:
        disc = FullPermutationComparator(input_dim * 2, img_res,
                                         "image").to(device)

        disc_summary, _ = torchsummary.summary_string(disc,
                                                      enc_input_size,
                                                      device=device,
                                                      batch_size=batch_size)
        disc_param_size = float(re.search(pattern, disc_summary).group(1))
        disc_forward_size = float(re.search(pattern2, disc_summary).group(1))
        discriminators["recons_disc"] = (disc_param_size,
                                         2 * disc_forward_size)

    encoder_param_size = float(re.search(pattern, encoder_summary).group(1))
    encoder_forward_size = float(re.search(pattern2, encoder_summary).group(1))
    decoder_param_size = float(re.search(pattern, decoder_summary).group(1))
    decoder_forward_size = float(re.search(pattern2, decoder_summary).group(1))

    if verbose:
        if "ACAI" in str(type(model)):
            print(
                f"discriminator:\n\tparams:{disc_param_size}\n\tforward:{disc_forward_size}"
            )

        if loss == "gan":
            print(
                f"reconstruction discriminator:\n\tparams:{disc_param_size}\n\tforward:{disc_forward_size}"
            )

        print(
            f"encoder:\n\tparams:{encoder_param_size}\n\tforward:{encoder_forward_size}"
        )
        print(
            f"decoder:\n\tparams:{decoder_param_size}\n\tforward:{decoder_forward_size}"
        )

    encoder = {"params": encoder_param_size, "forward": encoder_forward_size}
    decoder = {"params": decoder_param_size, "forward": decoder_forward_size}

    return encoder, decoder, discriminators
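The size extraction in model_summary() reduces to two regular expressions applied to the text that torchsummary.summary_string returns (its "Params size (MB)" and "Forward/backward pass size (MB)" lines). A minimal standalone sketch of that parsing; the Sequential model is a hypothetical stand-in.

import re

import torch.nn as nn
import torchsummary

params_mb_re = re.compile(r"Params size \(MB\):(.*)\n")
forward_mb_re = re.compile(r"Forward/backward pass size \(MB\):(.*)\n")

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())  # hypothetical

summary_text, _ = torchsummary.summary_string(model, (3, 32, 32), device="cpu")
params_mb = float(params_mb_re.search(summary_text).group(1))
forward_mb = float(forward_mb_re.search(summary_text).group(1))
print("params (MB):", params_mb, "forward/backward (MB):", forward_mb)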
Example #7
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--type',
                        default='linear',
                        choices=['linear', 'conv'],
                        help='model architecture type: ' +
                        ' | '.join(['linear', 'conv']) + ' (default: linear)')
    parser.add_argument(
        '--model',
        default='',
        type=str,
        metavar='MODEL_PATH',
        help=
        'path to model file to load both its architecture and weights (default: none)'
    )
    parser.add_argument(
        '--weights',
        default='',
        type=str,
        metavar='WEIGHTS_PATH',
        help='path to file to load its weights (default: none)')
    parser.add_argument('--shift-depth',
                        type=int,
                        default=0,
                        help='how many layers to convert to shift')
    parser.add_argument(
        '-st',
        '--shift-type',
        default='PS',
        choices=['Q', 'PS'],
        help=
        'type of DeepShift method for training and representing weights (default: PS)'
    )
    parser.add_argument('-r',
                        '--rounding',
                        default='deterministic',
                        choices=['deterministic', 'stochastic'],
                        help='type of rounding (default: deterministic)')
    parser.add_argument('-wb',
                        '--weight-bits',
                        type=int,
                        default=5,
                        help='number of bits to represent the shift weights')
    parser.add_argument(
        '-ab',
        '--activation-bits',
        nargs='+',
        default=[16, 16],
        help=
        'number of integer and fraction bits to represent activation (fixed point format)'
    )
    parser.add_argument('-j',
                        '--workers',
                        default=1,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 1)')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('-opt',
                        '--optimizer',
                        metavar='OPT',
                        default="SGD",
                        help='optimizer algorithm')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.0,
                        metavar='M',
                        help='SGD momentum (default: 0.0)')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='CHECKPOINT_PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e',
                        '--evaluate',
                        dest='evaluate',
                        action='store_true',
                        help='only evaluate model on validation set')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--pretrained',
                        dest='pretrained',
                        default=False,
                        type=lambda x: bool(distutils.util.strtobool(x)),
                        help='use pre-trained model of full conv or fc model')

    parser.add_argument('--save-model',
                        default=True,
                        type=lambda x: bool(distutils.util.strtobool(x)),
                        help='For Saving the current Model (default: True)')
    parser.add_argument(
        '--print-weights',
        default=True,
        type=lambda x: bool(distutils.util.strtobool(x)),
        help='For printing the weights of Model (default: True)')
    parser.add_argument('--desc',
                        type=str,
                        default=None,
                        help='description to append to model directory name')
    parser.add_argument('--use-kernel',
                        type=lambda x: bool(distutils.util.strtobool(x)),
                        default=False,
                        help='whether using custom shift kernel')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    assert len(
        args.activation_bits
    ) == 2, "activation-bits argument needs to be a tuple of 2 values representing number of integer bits and number of fraction bits, e.g., '3 5' for 8-bits fixed point or '3 13' for 16-bits fixed point"
    [args.activation_integer_bits,
     args.activation_fraction_bits] = args.activation_bits
    [args.activation_integer_bits, args.activation_fraction_bits] = [
        int(args.activation_integer_bits),
        int(args.activation_fraction_bits)
    ]

    if (args.evaluate is False and args.use_kernel is True):
        raise ValueError(
            'Our custom kernel currently supports inference only, not training.'
        )

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {
        'num_workers': args.workers,
        'pin_memory': True
    } if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '../data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.1307, ),
                    (0.3081, ))  # transforms.Normalize((0,), (255,))
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '../data',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.1307, ),
                    (0.3081, ))  # transforms.Normalize((0,), (255,))
            ])),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs)

    if args.model:
        if args.type or args.pretrained:
            print(
                "WARNING: Ignoring arguments \"type\" and \"pretrained\" when creating model..."
            )
        model = None
        saved_checkpoint = torch.load(args.model)
        if isinstance(saved_checkpoint, nn.Module):
            model = saved_checkpoint
        elif "model" in saved_checkpoint:
            model = saved_checkpoint["model"]
        else:
            raise Exception("Unable to load model from " + args.model)
    else:
        if args.type == 'linear':
            model = LinearMNIST().to(device)
        elif args.type == 'conv':
            model = ConvMNIST().to(device)
        else:
            raise ValueError("Unknown model type: " + str(args.type))

        if args.pretrained:
            model.load_state_dict(
                torch.load("./models/mnist/simple_" + args.type +
                           "/shift_0/weights.pth"))
            model = model.to(device)

    model_rounded = None

    if args.weights:
        saved_weights = torch.load(args.weights)
        if isinstance(saved_weights, nn.Module):
            state_dict = saved_weights.state_dict()
        elif "state_dict" in saved_weights:
            state_dict = saved_weights["state_dict"]
        else:
            state_dict = saved_weights

        model.load_state_dict(state_dict)

    if args.shift_depth > 0:
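        # convert_to_shift (assumed to come from the DeepShift codebase) replaces up
        # to `shift_depth` of the model's Conv2d/Linear layers with shift-based layers
        # whose weights are constrained to signed powers of two, so multiplications
        # become bit shifts and sign flips. 'Q' quantizes trained weights to powers
        # of two, while 'PS' trains the shift and sign parameters directly.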
        model, _ = convert_to_shift(
            model,
            args.shift_depth,
            args.shift_type,
            convert_all_linear=(args.type != 'linear'),
            convert_weights=True,
            use_kernel=args.use_kernel,
            use_cuda=use_cuda,
            rounding=args.rounding,
            weight_bits=args.weight_bits,
            act_integer_bits=args.activation_integer_bits,
            act_fraction_bits=args.activation_fraction_bits)
        model = model.to(device)
    elif args.use_kernel and args.shift_depth == 0:
        model = convert_to_unoptimized(model)
        model = model.to(device)

    loss_fn = F.cross_entropy  # F.nll_loss
    # define optimizer
    optimizer = None
    if (args.optimizer.lower() == "sgd"):
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum)
    elif (args.optimizer.lower() == "adadelta"):
        optimizer = torch.optim.Adadelta(model.parameters(), args.lr)
    elif (args.optimizer.lower() == "adagrad"):
        optimizer = torch.optim.Adagrad(model.parameters(), args.lr)
    elif (args.optimizer.lower() == "adam"):
        optimizer = torch.optim.Adam(model.parameters(), args.lr)
    elif (args.optimizer.lower() == "rmsprop"):
        optimizer = torch.optim.RMSprop(model.parameters(), args.lr)
    elif (args.optimizer.lower() == "radam"):
        optimizer = optim.RAdam(model.parameters(), args.lr)
    elif (args.optimizer.lower() == "ranger"):
        optimizer = optim.Ranger(model.parameters(), args.lr)
    else:
        raise ValueError("Optimizer type '{}' is not supported or known".format(
            args.optimizer))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if 'state_dict' in checkpoint:
                model.load_state_dict(checkpoint['state_dict'])
            else:
                model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}'".format(args.resume))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # name model sub-directory "shift_all" if all layers are converted to shift layers
    conv2d_layers_count = count_layer_type(
        model, nn.Conv2d) + count_layer_type(model,
                                             unoptimized.UnoptimizedConv2d)
    linear_layers_count = count_layer_type(
        model, nn.Linear) + count_layer_type(model,
                                             unoptimized.UnoptimizedLinear)
    if (args.shift_depth > 0):
        if (args.shift_type == 'Q'):
            shift_label = "shift_q"
        else:
            shift_label = "shift_ps"
    else:
        shift_label = "shift"

    if (conv2d_layers_count == 0 and linear_layers_count == 0):
        shift_label += "_all"
    else:
        shift_label += "_%s" % (args.shift_depth)

    if (args.shift_depth > 0):
        shift_label += "_wb_%s" % (args.weight_bits)

    if (args.desc is not None and len(args.desc) > 0):
        desc_label = "_%s" % (args.desc)
    else:
        desc_label = ""

    model_name = 'simple_%s/%s%s' % (args.type, shift_label, desc_label)
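    # Example directory names produced by the labels above (the weight_bits value
    # is illustrative): 'simple_conv/shift_0' when no layers are converted, or
    # 'simple_conv/shift_ps_all_wb_5' when every layer is a 'PS' shift layer.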

    # if evaluating round weights to ensure that the results are due to powers of 2 weights
    if (args.evaluate):
        model = round_shift_weights(model)

    model_summary = None
    try:
        model_summary, model_params_info = torchsummary.summary_string(
            model, input_size=(1, 28, 28))
        print(model_summary)
        print(
            "WARNING: The summary function reports duplicate parameters for the multi-GPU case"
        )
    except:
        print("WARNING: Unable to obtain summary of model")

    model_dir = os.path.join(os.getcwd(), "models", "mnist", model_name)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir, exist_ok=True)

    if (args.save_model):
        with open(os.path.join(model_dir, 'command_args.txt'),
                  'w') as command_args_file:
            for arg, value in sorted(vars(args).items()):
                command_args_file.write(arg + ": " + str(value) + "\n")

        with open(os.path.join(model_dir, 'model_summary.txt'),
                  'w') as summary_file:
            with redirect_stdout(summary_file):
                if (model_summary is not None):
                    print(model_summary)
                    print(
                        "WARNING: The summary function reports duplicate parameters for the multi-GPU case"
                    )
                else:
                    print("WARNING: Unable to obtain summary of model")

    start_time = time.time()
    if args.evaluate:
        test_loss, correct = test(args, model, device, test_loader, loss_fn)
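        # The MNIST test set has 10,000 images, so correct / 1e4 is the test accuracy.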
        test_log = [(test_loss, correct / 1e4)]
        with open(os.path.join(model_dir, "test_log.csv"),
                  "w") as test_log_file:
            test_log_csv = csv.writer(test_log_file)
            test_log_csv.writerow(['test_loss', 'correct'])
            test_log_csv.writerows(test_log)
    else:
        train_log = []
        for epoch in range(1, args.epochs + 1):
            train_loss = train(args, model, device, train_loader, loss_fn,
                               optimizer, epoch)
            test_loss, correct = test(args, model, device, test_loader,
                                      loss_fn)

            if (args.print_weights):
                with open(
                        os.path.join(model_dir,
                                     'weights_log_' + str(epoch) + '.txt'),
                        'w') as weights_log_file:
                    with redirect_stdout(weights_log_file):
                        # Log model's state_dict
                        print("Model's state_dict:")
                        # TODO: Use checkpoint above
                        for param_tensor in model.state_dict():
                            print(param_tensor, "\t",
                                  model.state_dict()[param_tensor].size())
                            print(model.state_dict()[param_tensor])
                            print("")

            train_log.append((epoch, train_loss, test_loss, correct / 1e4))

        with open(os.path.join(model_dir, "train_log.csv"),
                  "w") as train_log_file:
            train_log_csv = csv.writer(train_log_file)
            train_log_csv.writerow(
                ['epoch', 'train_loss', 'test_loss', 'test_accuracy'])
            train_log_csv.writerows(train_log)

        if (args.save_model):
            model_rounded = round_shift_weights(model, clone=True)

            torch.save(model_rounded, os.path.join(model_dir, "model.pth"))
            torch.save(model_rounded.state_dict(),
                       os.path.join(model_dir, "weights.pth"))

    end_time = time.time()
    print("Total Time:", end_time - start_time)

    if (args.print_weights):
        if (model_rounded is None):
            model_rounded = round_shift_weights(model, clone=True)

        with open(os.path.join(model_dir, 'weights_log.txt'),
                  'w') as weights_log_file:
            with redirect_stdout(weights_log_file):
                # Log model's state_dict
                print("Model's state_dict:")
                # TODO: Use checkpoint above
                for param_tensor in model_rounded.state_dict():
                    print(param_tensor, "\t",
                          model_rounded.state_dict()[param_tensor].size())
                    print(model_rounded.state_dict()[param_tensor])
                    print("")
Example #8
    torch.manual_seed(43)

    test_loader = DataLoader(test_dataset,
                             batch_size,
                             num_workers=4,
                             pin_memory=True)

    if (job["model"] == "resnet50"):
        model = models.resnet50(pretrained=True)
    elif (job["model"] == "resnet18"):
        model = models.resnet18(pretrained=True)
    elif (job["model"] == "squeezenet"):
        model = models.squeezenet1_0(pretrained=True)
    else:
        model = models.googlenet(pretrained=True)

    model.to(device)
    print(tuple(test_loader.dataset[0][0].shape))
    # quit()
    job["size"] = float(
        summary_string(model,
                       batch_size=job["batchSize"],
                       input_size=tuple(test_loader.dataset[0][0].shape))[2])
    job["id"] = str(uuid.uuid1())

    results.append(job)

with open("jobs.json", "w") as f:
    json.dump(results, f)
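
Note that the stock torchsummary.summary_string returns only a summary string and a (total_params, trainable_params) pair, so indexing its result with [2] as above implies a modified summary_string that also reports size estimates. A minimal fallback sketch for estimating the parameter footprint directly (the helper name is illustrative and float32 storage is assumed):

import torch.nn as nn

def param_size_mb(model: nn.Module) -> float:
    # Rough parameter memory in MiB, assuming 4 bytes (float32) per parameter.
    n_params = sum(p.numel() for p in model.parameters())
    return n_params * 4 / (1024 ** 2)

# e.g. job["size"] = param_size_mb(model) when only parameter counts are available.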