Example #1
def main_worker(gpu, ngpus_per_node, args):
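    # Per-GPU worker process: builds the model (optionally converting layers to
    # shift layers), wraps it for distributed or data-parallel execution, and
    # runs ImageNet training or evaluation.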
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    if args.model:
        if args.arch or args.pretrained:
            print(
                "WARNING: Ignoring arguments \"arch\" and \"pretrained\" when creating model..."
            )
        model = None
        saved_checkpoint = torch.load(args.model)
        if isinstance(saved_checkpoint, nn.Module):
            model = saved_checkpoint
        elif "model" in saved_checkpoint:
            model = saved_checkpoint["model"]
        else:
            raise Exception("Unable to load model from " + args.model)

        if (args.gpu is not None):
            model.cuda(args.gpu)
    elif args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model_rounded = None

    if args.weights:
        saved_weights = torch.load(args.weights)
        if isinstance(saved_weights, nn.Module):
            state_dict = saved_weights.state_dict()
        elif "state_dict" in saved_weights:
            state_dict = saved_weights["state_dict"]
        else:
            state_dict = saved_weights

        try:
            model.load_state_dict(state_dict)
        except:
            # create new OrderedDict that does not contain module.
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove module.
                new_state_dict[name] = v

            model.load_state_dict(new_state_dict)

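    # Optionally replace standard layers with shift-based (DeepShift) layers;
    # convert_weights controls whether existing weights are converted to their
    # power-of-two representation.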
    if args.shift_depth > 0:
        model, _ = convert_to_shift(model,
                                    args.shift_depth,
                                    args.shift_type,
                                    convert_weights=args.pretrained
                                    or args.weights,
                                    freeze_sign=(args.lr_sign == 0),
                                    use_kernel=args.use_kernel)
    elif args.use_kernel and args.shift_depth == 0:
        model = convert_to_unoptimized(model)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # create optimizer
    model_other_params = []
    model_sign_params = []
    model_shift_params = []

    for name, param in model.named_parameters():
        if (name.endswith(".sign")):
            model_sign_params.append(param)
        elif (name.endswith(".shift")):
            model_shift_params.append(param)
        else:
            model_other_params.append(param)

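    # Give the sign and shift parameters their own optimizer groups so they can
    # use a separate learning rate and skip weight decay.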
    params_dict = [{
        "params": model_other_params
    }, {
        "params": model_sign_params,
        'lr': args.lr_sign if args.lr_sign is not None else args.lr,
        'weight_decay': 0
    }, {
        "params": model_shift_params,
        'lr': args.lr,
        'weight_decay': 0
    }]

    # define optimizer
    optimizer = None
    if (args.optimizer.lower() == "sgd"):
        optimizer = torch.optim.SGD(params_dict,
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "adadelta"):
        optimizer = torch.optim.Adadelta(params_dict,
                                         args.lr,
                                         weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "adagrad"):
        optimizer = torch.optim.Adagrad(params_dict,
                                        args.lr,
                                        weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "adam"):
        optimizer = torch.optim.Adam(params_dict,
                                     args.lr,
                                     weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "rmsprop"):
        optimizer = torch.optim.RMSprop(params_dict,
                                        args.lr,
                                        weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "radam"):
        optimizer = optim.RAdam(params_dict,
                                args.lr,
                                weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "ranger"):
        optimizer = optim.Ranger(params_dict,
                                 args.lr,
                                 weight_decay=args.weight_decay)
    else:
        raise ValueError("Optimizer type: ", args.optimizer,
                         " is not supported or known")

    lr_scheduler = None
    if args.opt_ckpt:
        print(
            "WARNING: Ignoring arguments \"lr\", \"momentum\", \"weight_decay\", and \"lr_schedule\""
        )

        opt_ckpt = torch.load(args.opt_ckpt)
        if 'optimizer' in opt_ckpt:
            opt_ckpt = opt_ckpt['optimizer']
        optimizer.load_state_dict(opt_ckpt)

        if 'lr_scheduler' in opt_ckpt:
            lr_scheduler = opt_ckpt['lr_scheduler']

    if (args.lr_schedule and lr_scheduler is None):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            try:
                model.load_state_dict(checkpoint['state_dict'])
            except:
                # create new OrderedDict that does not contain module.
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    if args.arch.startswith('alexnet') or args.arch.startswith(
                            'vgg'):
                        if (k.startswith("features")):
                            name = k[0:9] + k[
                                9 + 7:]  # remove "module" after features
                        else:
                            name = k
                    else:
                        name = k[7:]  # remove "module" at beginning of name
                    new_state_dict[name] = v

                # load params
                model.load_state_dict(new_state_dict)
            optimizer.load_state_dict(checkpoint['optimizer'])
            if 'lr_scheduler' in checkpoint and checkpoint[
                    'lr_scheduler'] is not None:
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # if evaluating, round the weights to ensure that the results are due to power-of-two weights
    if (args.evaluate):
        model = round_shift_weights(model)

    cudnn.benchmark = True

    # model_tmp_copy = copy.deepcopy(model)  # we noticed calling summary() on the original model degrades its accuracy, so we call summary() on a copy of the model
    # try:
    #     summary(model_tmp_copy, input_size=(3, 224, 224))
    #     print("WARNING: The summary function reports duplicate parameters for multi-GPU case")
    # except:
    #     print("WARNING: Unable to obtain summary of model")

    # name model sub-directory "shift_all" if all layers are converted to shift layers
    conv2d_layers_count = count_layer_type(model, nn.Conv2d)
    linear_layers_count = count_layer_type(model, nn.Linear)
    if (args.shift_type == 'Q'):
        shift_label = "shift_q"
    else:
        shift_label = "shift"

    if (conv2d_layers_count == 0 and linear_layers_count == 0):
        shift_label += "_all"
    else:
        shift_label += "_%s" % (args.shift_depth)

    if args.desc is not None and len(args.desc) > 0:
        model_name = '%s/%s_%s' % (args.arch, args.desc, shift_label)
    else:
        model_name = '%s/%s' % (args.arch, shift_label)

    model_dir = os.path.join(os.getcwd(), "models", "imagenet", model_name)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir, exist_ok=True)

    if (args.save_model):
        with open(os.path.join(model_dir, 'command_args.txt'),
                  'w') as command_args_file:
            for arg, value in sorted(vars(args).items()):
                command_args_file.write(arg + ": " + str(value) + "\n")

    #     with open(os.path.join(model_dir, 'model_summary.txt'), 'w') as summary_file:
    #         with redirect_stdout(summary_file):
    #             try:
    #                 # TODO: make this summary function deal with parameters that are not named "weight" and "bias"
    #                 summary(model_tmp_copy, input_size=(3, 224, 224))
    #                 print("WARNING: The summary function reports duplicate parameters for multi-GPU case")
    #             except:
    #                 print("WARNING: Unable to obtain summary of model")

    # del model_tmp_copy # to save memory

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    start_time = time.time()

    if args.evaluate:
        start_log_time = time.time()
        val_log = validate(val_loader, model, criterion, args)
        val_log = [val_log]

        with open(os.path.join(model_dir, "test_log.csv"),
                  "w") as test_log_file:
            test_log_csv = csv.writer(test_log_file)
            test_log_csv.writerow([
                'test_loss', 'test_top1_acc', 'test_top5_acc', 'test_time',
                'cumulative_time'
            ])
            test_log_csv.writerows(val_log +
                                   [(time.time() - start_log_time, )])
    else:
        train_log = []

        with open(os.path.join(model_dir, "train_log.csv"),
                  "w") as train_log_file:
            train_log_csv = csv.writer(train_log_file)
            train_log_csv.writerow([
                'epoch', 'train_loss', 'train_top1_acc', 'train_top5_acc',
                'train_time', 'test_loss', 'test_top1_acc', 'test_top5_acc',
                'test_time', 'cumulative_time'
            ])

        start_log_time = time.time()
        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                train_sampler.set_epoch(epoch)
            adjust_learning_rate(optimizer, epoch, args)

            # train for one epoch
            print("current lr ",
                  [param['lr'] for param in optimizer.param_groups])
            train_epoch_log = train(train_loader, model, criterion, optimizer,
                                    epoch, args)
            if (args.lr_schedule):
                lr_scheduler.step()

            # evaluate on validation set
            val_epoch_log = validate(val_loader, model, criterion, args)
            acc1 = val_epoch_log[2]

            # append to log
            with open(os.path.join(model_dir, "train_log.csv"),
                      "a") as train_log_file:
                train_log_csv = csv.writer(train_log_file)
                train_log_csv.writerow(
                    ((epoch, ) + train_epoch_log + val_epoch_log +
                     (time.time() - start_log_time, )))

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

            if (args.print_weights):
                with open(
                        os.path.join(model_dir,
                                     'weights_log_' + str(epoch) + '.txt'),
                        'w') as weights_log_file:
                    with redirect_stdout(weights_log_file):
                        # Log model's state_dict
                        print("Model's state_dict:")
                        # TODO: Use checkpoint above
                        for param_tensor in model.state_dict():
                            print(param_tensor, "\t",
                                  model.state_dict()[param_tensor].size())
                            print(model.state_dict()[param_tensor])
                            print("")

            if not args.multiprocessing_distributed or (
                    args.multiprocessing_distributed
                    and args.rank % ngpus_per_node == 0):
                if is_best:
                    try:
                        if (args.save_model):
                            model_rounded = round_shift_weights(model,
                                                                clone=True)

                            torch.save(model_rounded.state_dict(),
                                       os.path.join(model_dir, "weights.pth"))
                            torch.save(model_rounded,
                                       os.path.join(model_dir, "model.pth"))
                    except:
                        print("WARNING: Unable to save model.pth")

                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_acc1': best_acc1,
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler,
                    }, is_best, model_dir)

    end_time = time.time()
    print("Total Time:", end_time - start_time)

    if (args.print_weights):
        if (model_rounded is None):
            model_rounded = round_shift_weights(model, clone=True)

        with open(os.path.join(model_dir, 'weights_log.txt'),
                  'w') as weights_log_file:
            with redirect_stdout(weights_log_file):
                # Log model's state_dict
                print("Model's state_dict:")
                # TODO: Use checkpoint above
                for param_tensor in model_rounded.state_dict():
                    print(param_tensor, "\t",
                          model_rounded.state_dict()[param_tensor].size())
                    print(model_rounded.state_dict()[param_tensor])
                    print("")
Example #2
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--type',
                        default='linear',
                        choices=['linear', 'conv'],
                        help='model architecture type: ' +
                        ' | '.join(['linear', 'conv']) + ' (default: linear)')
    parser.add_argument(
        '--model',
        default='',
        type=str,
        metavar='MODEL_PATH',
        help=
        'path to model file to load both its architecture and weights (default: none)'
    )
    parser.add_argument(
        '--weights',
        default='',
        type=str,
        metavar='WEIGHTS_PATH',
        help='path to file to load its weights (default: none)')
    parser.add_argument('--shift-depth',
                        type=int,
                        default=0,
                        help='how many layers to convert to shift')
    parser.add_argument(
        '-st',
        '--shift-type',
        default='PS',
        choices=['Q', 'PS'],
        help=
        'type of DeepShift method for training and representing weights (default: PS)'
    )
    parser.add_argument('-j',
                        '--workers',
                        default=1,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 1)')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('-opt',
                        '--optimizer',
                        metavar='OPT',
                        default="SGD",
                        help='optimizer algorithm')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.0,
                        metavar='M',
                        help='SGD momentum (default: 0.0)')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='CHECKPOINT_PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e',
                        '--evaluate',
                        dest='evaluate',
                        action='store_true',
                        help='only evaluate model on validation set')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--pretrained',
                        dest='pretrained',
                        default=False,
                        type=lambda x: bool(distutils.util.strtobool(x)),
                        help='use pre-trained model of full conv or fc model')

    parser.add_argument('--save-model',
                        default=True,
                        type=lambda x: bool(distutils.util.strtobool(x)),
                        help='For Saving the current Model (default: True)')
    parser.add_argument(
        '--print-weights',
        default=True,
        type=lambda x: bool(distutils.util.strtobool(x)),
        help='For printing the weights of Model (default: True)')
    parser.add_argument('--desc',
                        type=str,
                        default=None,
                        help='description to append to model directory name')
    parser.add_argument('--use-kernel',
                        type=lambda x: bool(distutils.util.strtobool(x)),
                        default=False,
                        help='whether using custom shift kernel')
    args = parser.parse_args()
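    # Illustrative invocation (the script name "mnist.py" is an assumption, not
    # taken from the source):
    #   python mnist.py --type conv --epochs 10 --shift-depth 2 --optimizer adam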
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    if (args.evaluate is False and args.use_kernel is True):
        raise ValueError(
            'Our custom kernel currently supports inference only, not training.'
        )

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {
        'num_workers': args.workers,
        'pin_memory': True
    } if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '../data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.1307, ),
                    (0.3081, ))  # transforms.Normalize((0,), (255,))
            ])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            '../data',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.1307, ),
                    (0.3081, ))  # transforms.Normalize((0,), (255,))
            ])),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs)

    if args.model:
        if args.type or args.pretrained:
            print(
                "WARNING: Ignoring arguments \"type\" and \"pretrained\" when creating model..."
            )
        model = None
        saved_checkpoint = torch.load(args.model)
        if isinstance(saved_checkpoint, nn.Module):
            model = saved_checkpoint
        elif "model" in saved_checkpoint:
            model = saved_checkpoint["model"]
        else:
            raise Exception("Unable to load model from " + args.model)
    else:
        if args.type == 'linear':
            model = LinearMNIST().to(device)
        elif args.type == 'conv':
            model = ConvMNIST().to(device)

        if args.pretrained:
            model.load_state_dict(
                torch.load("./models/mnist/simple_" + args.type +
                           "/shift_0/weights.pth"))
            model = model.to(device)

    model_rounded = None

    if args.weights:
        saved_weights = torch.load(args.weights)
        if isinstance(saved_weights, nn.Module):
            state_dict = saved_weights.state_dict()
        elif "state_dict" in saved_weights:
            state_dict = saved_weights["state_dict"]
        else:
            state_dict = saved_weights

        model.load_state_dict(state_dict)

    if args.shift_depth > 0:
        model, _ = convert_to_shift(model,
                                    args.shift_depth,
                                    args.shift_type,
                                    convert_all_linear=(args.type != 'linear'),
                                    convert_weights=True,
                                    use_kernel=args.use_kernel,
                                    use_cuda=use_cuda)
        model = model.to(device)
    elif args.use_kernel and args.shift_depth == 0:
        model = convert_to_unoptimized(model)
        model = model.to(device)

    loss_fn = F.cross_entropy  # F.nll_loss
    # define optimizer
    optimizer = None
    if (args.optimizer.lower() == "sgd"):
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum)
    elif (args.optimizer.lower() == "adadelta"):
        optimizer = torch.optim.Adadelta(model.parameters(), args.lr)
    elif (args.optimizer.lower() == "adagrad"):
        optimizer = torch.optim.Adagrad(model.parameters(), args.lr)
    elif (args.optimizer.lower() == "adam"):
        optimizer = torch.optim.Adam(model.parameters(), args.lr)
    elif (args.optimizer.lower() == "rmsprop"):
        optimizer = torch.optim.RMSprop(model.parameters(), args.lr)
    elif (args.optimizer.lower() == "radam"):
        optimizer = optim.RAdam(model.parameters(), args.lr)
    elif (args.optimizer.lower() == "ranger"):
        optimizer = optim.Ranger(model.parameters(), args.lr)
    else:
        raise ValueError("Optimizer type: ", args.optimizer,
                         " is not supported or known")

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if 'state_dict' in checkpoint:
                model.load_state_dict(checkpoint['state_dict'])
            else:
                model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}'".format(args.resume))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.desc is not None and len(args.desc) > 0:
        model_name = 'simple_%s/%s_shift_%s' % (args.type, args.desc,
                                                args.shift_depth)
    else:
        model_name = 'simple_%s/shift_%s' % (args.type, args.shift_depth)

    # if evaluating, round the weights to ensure that the results are due to power-of-two weights
    if (args.evaluate):
        model = round_shift_weights(model)

    # We noticed that calling summary() on the original model degrades its
    # accuracy, so we call summary() on a copy of the model instead.
    model_tmp_copy = copy.deepcopy(model)
    try:
        summary(model_tmp_copy,
                input_size=(1, 28, 28),
                device=("cuda" if use_cuda else "cpu"))
        print(
            "WARNING: The summary function reports duplicate parameters for multi-GPU case"
        )
    except:
        print("WARNING: Unable to obtain summary of model")

    model_dir = os.path.join(os.getcwd(), "models", "mnist", model_name)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir, exist_ok=True)

    if (args.save_model):
        with open(os.path.join(model_dir, 'command_args.txt'),
                  'w') as command_args_file:
            for arg, value in sorted(vars(args).items()):
                command_args_file.write(arg + ": " + str(value) + "\n")

        with open(os.path.join(model_dir, 'model_summary.txt'),
                  'w') as summary_file:
            with redirect_stdout(summary_file):
                try:
                    summary(model_tmp_copy,
                            input_size=(1, 28, 28),
                            device=("cuda" if use_cuda else "cpu"))
                    print(
                        "WARNING: The summary function reports duplicate parameters for multi-GPU case"
                    )
                except:
                    print("WARNING: Unable to obtain summary of model")

    # del model_tmp_copy

    start_time = time.time()
    if args.evaluate:
        test_loss, correct = test(args, model, device, test_loader, loss_fn)
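        # MNIST's test set has 10,000 samples, so correct / 1e4 is the test accuracy.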
        test_log = [(test_loss, correct / 1e4)]
        with open(os.path.join(model_dir, "test_log.csv"),
                  "w") as test_log_file:
            test_log_csv = csv.writer(test_log_file)
            test_log_csv.writerow(['test_loss', 'test_accuracy'])
            test_log_csv.writerows(test_log)
    else:
        train_log = []
        for epoch in range(1, args.epochs + 1):
            train_loss = train(args, model, device, train_loader, loss_fn,
                               optimizer, epoch)
            test_loss, correct = test(args, model, device, test_loader,
                                      loss_fn)

            if (args.print_weights):
                with open(
                        os.path.join(model_dir,
                                     'weights_log_' + str(epoch) + '.txt'),
                        'w') as weights_log_file:
                    with redirect_stdout(weights_log_file):
                        # Log model's state_dict
                        print("Model's state_dict:")
                        # TODO: Use checkpoint above
                        for param_tensor in model.state_dict():
                            print(param_tensor, "\t",
                                  model.state_dict()[param_tensor].size())
                            print(model.state_dict()[param_tensor])
                            print("")

            train_log.append((epoch, train_loss, test_loss, correct / 1e4))

        with open(os.path.join(model_dir, "train_log.csv"),
                  "w") as train_log_file:
            train_log_csv = csv.writer(train_log_file)
            train_log_csv.writerow(
                ['epoch', 'train_loss', 'test_loss', 'test_accuracy'])
            train_log_csv.writerows(train_log)

        if (args.save_model):
            model_rounded = round_shift_weights(model, clone=True)

            torch.save(model_rounded, os.path.join(model_dir, "model.pth"))
            torch.save(model_rounded.state_dict(),
                       os.path.join(model_dir, "weights.pth"))

    end_time = time.time()
    print("Total Time:", end_time - start_time)

    if (args.print_weights):
        if (model_rounded is None):
            model_rounded = round_shift_weights(model, clone=True)

        with open(os.path.join(model_dir, 'weights_log.txt'),
                  'w') as weights_log_file:
            with redirect_stdout(weights_log_file):
                # Log model's state_dict
                print("Model's state_dict:")
                # TODO: Use checkpoint above
                for param_tensor in model_rounded.state_dict():
                    print(param_tensor, "\t",
                          model_rounded.state_dict()[param_tensor].size())
                    print(model_rounded.state_dict()[param_tensor])
                    print("")
Example #3
def main_worker(gpu, ngpus_per_node, args):
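    # Per-GPU worker process for CIFAR-10: same overall structure as the ImageNet
    # worker above, but with CIFAR-10 data, 32x32 inputs, and weight-bit/rounding
    # options for the shift layers.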
    global best_acc1
    args.gpu = gpu
    num_classes = 10

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    if args.model:
        if args.arch or args.pretrained:
            print(
                "WARNING: Ignoring arguments \"arch\" and \"pretrained\" when creating model..."
            )
        model = None
        saved_checkpoint = torch.load(args.model)
        if isinstance(saved_checkpoint, nn.Module):
            model = saved_checkpoint
        elif "model" in saved_checkpoint:
            model = saved_checkpoint["model"]
        else:
            raise Exception("Unable to load model from " + args.model)

        if (args.gpu is not None):
            model.cuda(args.gpu)
    elif args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model_rounded = None

    #TODO: add option for finetune vs. feature extraction that only work if pretrained weights are imagenet
    if args.freeze and args.pretrained != "none":
        for param in model.parameters():
            param.requires_grad = False

    if args.weights:
        saved_weights = torch.load(args.weights)
        if isinstance(saved_weights, nn.Module):
            state_dict = saved_weights.state_dict()
        elif "state_dict" in saved_weights:
            state_dict = saved_weights["state_dict"]
        else:
            state_dict = saved_weights

        try:
            model.load_state_dict(state_dict)
        except:
            # create new OrderedDict that does not contain module.
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove module.
                new_state_dict[name] = v

            model.load_state_dict(new_state_dict)

    if args.shift_depth > 0:
        model, _ = convert_to_shift(model,
                                    args.shift_depth,
                                    args.shift_type,
                                    convert_weights=(args.pretrained != "none"
                                                     or args.weights),
                                    use_kernel=args.use_kernel,
                                    rounding=args.rounding,
                                    weight_bits=args.weight_bits)
    elif args.use_kernel and args.shift_depth == 0:
        model = convert_to_unoptimized(model)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            #TODO: Allow args.gpu to be a list of IDs
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if (args.arch.startswith('alexnet')):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # create optimizer
    model_other_params = []
    model_sign_params = []
    model_shift_params = []

    for name, param in model.named_parameters():
        if (name.endswith(".sign")):
            model_sign_params.append(param)
        elif (name.endswith(".shift")):
            model_shift_params.append(param)
        else:
            model_other_params.append(param)

    params_dict = [{
        "params": model_other_params
    }, {
        "params": model_sign_params,
        'lr': args.lr_sign if args.lr_sign is not None else args.lr,
        'weight_decay': 0
    }, {
        "params": model_shift_params,
        'lr': args.lr,
        'weight_decay': 0
    }]

    # define optimizer
    optimizer = None
    if (args.optimizer.lower() == "sgd"):
        optimizer = torch.optim.SGD(params_dict,
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "adadelta"):
        optimizer = torch.optim.Adadelta(params_dict,
                                         args.lr,
                                         weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "adagrad"):
        optimizer = torch.optim.Adagrad(params_dict,
                                        args.lr,
                                        weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "adam"):
        optimizer = torch.optim.Adam(params_dict,
                                     args.lr,
                                     weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "rmsprop"):
        optimizer = torch.optim.RMSprop(params_dict,
                                        args.lr,
                                        weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "radam"):
        optimizer = optim.RAdam(params_dict,
                                args.lr,
                                weight_decay=args.weight_decay)
    elif (args.optimizer.lower() == "ranger"):
        optimizer = optim.Ranger(params_dict,
                                 args.lr,
                                 weight_decay=args.weight_decay)
    else:
        raise ValueError("Optimizer type: ", args.optimizer,
                         " is not supported or known")

    # define learning rate schedule
    if (args.lr_schedule):
        if (args.lr_step_size is not None):
            lr_scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=args.lr_step_size)
        else:
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[80, 120, 160, 180],
                last_epoch=args.start_epoch - 1)

    if args.arch in ['resnet1202', 'resnet110']:
        # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up
        # then switch back. In this implementation it will correspond for first epoch.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr * 0.1

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # if evaluating, round the weights to ensure that the results are due to power-of-two weights
    if (args.evaluate):
        model = round_shift_weights(model)

    cudnn.benchmark = True

    model_summary = None
    try:
        model_summary, model_params_info = torchsummary.summary_string(
            model, input_size=(3, 32, 32))
        print(model_summary)
        print(
            "WARNING: The summary function reports duplicate parameters for multi-GPU case"
        )
    except:
        print("WARNING: Unable to obtain summary of model")

    # name model sub-directory "shift_all" if all layers are converted to shift layers
    conv2d_layers_count = count_layer_type(
        model, nn.Conv2d) + count_layer_type(model,
                                             unoptimized.UnoptimizedConv2d)
    linear_layers_count = count_layer_type(
        model, nn.Linear) + count_layer_type(model,
                                             unoptimized.UnoptimizedLinear)
    if (args.shift_depth > 0):
        if (args.shift_type == 'Q'):
            shift_label = "shift_q"
        else:
            shift_label = "shift_ps"
    else:
        shift_label = "shift"

    if (conv2d_layers_count == 0 and linear_layers_count == 0):
        shift_label += "_all"
    else:
        shift_label += "_%s" % (args.shift_depth)

    if (args.shift_depth > 0):
        shift_label += "_wb_%s" % (args.weight_bits)

    if (args.desc is not None and len(args.desc) > 0):
        desc_label = "_%s" % (args.desc)
    else:
        desc_label = ""

    model_name = '%s/%s%s' % (args.arch, shift_label, desc_label)

    model_dir = os.path.join(os.getcwd(), "models", "cifar10", model_name)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir, exist_ok=True)

    if (args.save_model):
        with open(os.path.join(model_dir, 'command_args.txt'),
                  'w') as command_args_file:
            for arg, value in sorted(vars(args).items()):
                command_args_file.write(arg + ": " + str(value) + "\n")

        with open(os.path.join(model_dir, 'model_summary.txt'),
                  'w') as summary_file:
            with redirect_stdout(summary_file):
                if (model_summary is not None):
                    print(model_summary)
                    print(
                        "WARNING: The summary function reports duplicate parameters for multi-GPU case"
                    )
                else:
                    print("WARNING: Unable to obtain summary of model")

    # Data loading code
    data_dir = "~/pytorch_datasets"
    os.makedirs(model_dir, exist_ok=True)

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010])

    train_dataset = datasets.CIFAR10(root=data_dir,
                                     train=True,
                                     transform=transforms.Compose([
                                         transforms.RandomHorizontalFlip(),
                                         transforms.RandomCrop(size=32,
                                                               padding=4),
                                         transforms.ToTensor(),
                                         normalize,
                                     ]),
                                     download=True)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
        root=data_dir,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    start_time = time.time()

    if args.evaluate:
        val_log = validate(val_loader, model, criterion, args)
        val_log = [val_log]

        with open(os.path.join(model_dir, "test_log.csv"),
                  "w") as test_log_file:
            test_log_csv = csv.writer(test_log_file)
            test_log_csv.writerow(['test_loss', 'test_top1_acc', 'test_time'])
            test_log_csv.writerows(val_log)
    else:
        train_log = []

        with open(os.path.join(model_dir, "train_log.csv"),
                  "w") as train_log_file:
            train_log_csv = csv.writer(train_log_file)
            train_log_csv.writerow([
                'epoch', 'train_loss', 'train_top1_acc', 'train_time',
                'test_loss', 'test_top1_acc', 'test_time'
            ])

        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                train_sampler.set_epoch(epoch)

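            # Alternate updates between the sign parameters (group 1) and the shift
            # parameters (group 2): odd epochs zero the sign group's lr, even epochs
            # zero the shift group's lr.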
            if (args.alternate_update):
                if epoch % 2 == 1:
                    optimizer.param_groups[1]['lr'] = 0
                    optimizer.param_groups[2]['lr'] = optimizer.param_groups[
                        0]['lr']
                else:
                    optimizer.param_groups[1]['lr'] = optimizer.param_groups[
                        0]['lr']
                    optimizer.param_groups[2]['lr'] = 0

            # train for one epoch
            print("current lr ",
                  [param['lr'] for param in optimizer.param_groups])
            train_epoch_log = train(train_loader, model, criterion, optimizer,
                                    epoch, args)
            if (args.lr_schedule):
                lr_scheduler.step()

            # evaluate on validation set
            val_epoch_log = validate(val_loader, model, criterion, args)
            acc1 = val_epoch_log[2]

            # append to log
            with open(os.path.join(model_dir, "train_log.csv"),
                      "a") as train_log_file:
                train_log_csv = csv.writer(train_log_file)
                train_log_csv.writerow(
                    ((epoch, ) + train_epoch_log + val_epoch_log))

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

            if (args.print_weights):
                with open(
                        os.path.join(model_dir,
                                     'weights_log_' + str(epoch) + '.txt'),
                        'w') as weights_log_file:
                    with redirect_stdout(weights_log_file):
                        # Log model's state_dict
                        print("Model's state_dict:")
                        # TODO: Use checkpoint above
                        for param_tensor in model.state_dict():
                            print(param_tensor, "\t",
                                  model.state_dict()[param_tensor].size())
                            print(model.state_dict()[param_tensor])
                            print("")

            if not args.multiprocessing_distributed or (
                    args.multiprocessing_distributed
                    and args.rank % ngpus_per_node == 0):
                if is_best:
                    try:
                        if (args.save_model):
                            model_rounded = round_shift_weights(model,
                                                                clone=True)

                            torch.save(model_rounded.state_dict(),
                                       os.path.join(model_dir, "weights.pth"))
                            torch.save(model_rounded,
                                       os.path.join(model_dir, "model.pth"))
                    except:
                        print("WARNING: Unable to save model.pth")

                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_acc1': best_acc1,
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler,
                    }, is_best, model_dir)

    end_time = time.time()
    print("Total Time:", end_time - start_time)

    if (args.print_weights):
        if (model_rounded is None):
            model_rounded = round_shift_weights(model, clone=True)

        with open(os.path.join(model_dir, 'weights_log.txt'),
                  'w') as weights_log_file:
            with redirect_stdout(weights_log_file):
                # Log model's state_dict
                print("Model's state_dict:")
                # TODO: Use checkpoint above
                for param_tensor in model_rounded.state_dict():
                    print(param_tensor, "\t",
                          model_rounded.state_dict()[param_tensor].size())
                    print(model_rounded.state_dict()[param_tensor])
                    print("")