def main():
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100

    trainset = dataloader(root=args.dataroot, train=True, download=True, transform=transform_train)
    # Each Horovod worker draws a distinct shard of the training set
    sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, num_replicas=hvd.size(), rank=hvd.rank())
    trainloader = data.DataLoader(dataset=trainset, batch_size=args.train_batch * world_size,
                                  shuffle=False, sampler=sampler)

    testset = dataloader(root=args.dataroot, train=False, download=False, transform=transform_test)
    testloader = data.DataLoader(testset, batch_size=args.test_batch * world_size,
                                 shuffle=False, num_workers=args.workers)

    # Model
    print("==> creating model '{}'".format("AlexNet"))
    model = AlexNet(num_classes=num_classes)
    device = torch.device('cuda', local_rank)
    model = model.to(device)
    # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
    print('Model on cuda:%d' % local_rank)
    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    # Wrap the optimizer with Horovod so gradients are averaged across workers
    optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
    # Broadcast the initial model parameters from rank 0 to all other workers
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        train_loss, train_acc = train(trainloader, model, criterion, optimizer, epoch, use_cuda)
        test_loss, test_acc = test(testloader, model, criterion, epoch, use_cuda)
        print('Rank:{} Epoch[{}/{}]: LR: {:.3f}, Train loss: {:.5f}, Test loss: {:.5f}, '
              'Train acc: {:.2f}, Test acc: {:.2f}.'.format(
                  local_rank, epoch + 1, args.epochs, state['lr'],
                  train_loss, test_loss, train_acc, test_acc))
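The function above assumes the usual Horovod boilerplate has already run: `hvd`, `local_rank`, `world_size`, and the parsed `args` must exist before `main()` is called. A minimal sketch of that setup, assuming the script is launched with `horovodrun` (the entry point shown here is illustrative, not from the original):

import torch
import horovod.torch as hvd

hvd.init()                          # start Horovod: one process per GPU
local_rank = hvd.local_rank()       # GPU index of this process on its machine
world_size = hvd.size()             # total number of workers across all machines
torch.cuda.set_device(local_rank)   # pin this process to its own GPU

if __name__ == '__main__':
    main()  # args and state are assumed to come from argparse code elsewhere in the script

Running, for example, `horovodrun -np 4 python train.py` (script name illustrative) then starts four such processes, each executing `main()` against its own GPU and its own shard of the training data.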
alexnet_dict = alexnet.state_dict()
# print(alexnet_dict.keys())
alexnet_pretrained = models.alexnet(pretrained=True)
pretrained_dict = alexnet_pretrained.state_dict()
# Keep only the pretrained weights whose names also exist in our model
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in alexnet_dict}
# print(pretrained_dict.keys())
# Drop the final classifier layer: its shape depends on the number of classes
pretrained_dict.pop("classifier.6.weight")
pretrained_dict.pop("classifier.6.bias")
alexnet_dict.update(pretrained_dict)
alexnet.load_state_dict(alexnet_dict)
# print(alexnet_dict.keys())
print("Load from pretrained")

# Freeze every parameter except the final classifier.
# Note: an optimizer built afterwards should receive only the parameters
# that still have requires_grad=True.
if freeze_layer:
    for name, value in alexnet.named_parameters():
        if (name != "classifier.6.weight") and (name != "classifier.6.bias"):
            value.requires_grad = False
    print("Freeze layer")

# Train on multiple GPUs
DEVICE_IDS = list(range(GPU_NUM))
# alexnet = alexnet.to(device)
if GPU_NUM > 1:
    alexnet = torch.nn.parallel.DataParallel(alexnet, device_ids=DEVICE_IDS)
alexnet = alexnet.to(device)
print(alexnet)
print("Network created")

# Data normalization (ImageNet statistics)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],