import sys
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR

# AlexNet is the project-local model definition, assumed importable from the repo.


def train(data_train, data_val, num_classes, num_epoch, milestones):
    model = AlexNet(num_classes, pretrain=False)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    lr_scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    since = time.time()
    best_acc = 0
    best = 0
    for epoch in range(num_epoch):
        print('Epoch {}/{}'.format(epoch + 1, num_epoch))
        print('-' * 10)

        # Training phase: iterate over the training data.
        running_loss = 0.0
        running_corrects = 0
        model.train()
        with torch.set_grad_enabled(True):
            for i, (inputs, labels) in enumerate(data_train):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                # Accumulate per-batch accuracy; dividing by the number of
                # batches below yields the mean batch accuracy.
                running_corrects += torch.sum(preds == labels.data) * 1. / inputs.size(0)
                print("\rIteration: {}/{}, Loss: {}.".format(
                    i + 1, len(data_train), loss.item()), end="")
                sys.stdout.flush()
        avg_loss = running_loss / len(data_train)
        t_acc = running_corrects.double() / len(data_train)

        # Validation phase: no gradients needed.
        running_loss = 0.0
        running_corrects = 0
        model.eval()
        with torch.set_grad_enabled(False):
            for i, (inputs, labels) in enumerate(data_val):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                running_loss += loss.item()
                running_corrects += torch.sum(preds == labels.data) * 1. / inputs.size(0)
        val_loss = running_loss / len(data_val)
        val_acc = running_corrects.double() / len(data_val)

        print()
        print('Train Loss: {:.4f} Acc: {:.4f}'.format(avg_loss, t_acc))
        print('Val Loss: {:.4f} Acc: {:.4f}'.format(val_loss, val_acc))
        print('lr rate: {:.6f}'.format(optimizer.param_groups[0]['lr']))
        print()

        if val_acc > best_acc:
            best_acc = val_acc
            best = epoch + 1
        lr_scheduler.step()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Validation Accuracy: {}, Epoch: {}'.format(best_acc, best))
    return model
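# A minimal usage sketch for train() above, assuming a torchvision-style
# ImageFolder directory layout. The data paths, image size, batch size,
# class count, and milestones below are illustrative assumptions, not
# values taken from this repository.
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

if __name__ == '__main__':
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),  # AlexNet-style input resolution
        transforms.ToTensor(),
    ])
    # Hypothetical dataset roots; replace with the real paths.
    train_set = ImageFolder('data/train', transform=preprocess)
    val_set = ImageFolder('data/val', transform=preprocess)
    data_train = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
    data_val = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)

    # MultiStepLR inside train() decays the learning rate by 10x once each
    # milestone epoch is passed.
    model = train(data_train, data_val, num_classes=2, num_epoch=40,
                  milestones=[20, 30])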
                                data_preprocess=valid_data_preprocess)
# test_loader = cfg.dataset_loader(root=cfg.cat_dog_test, train=False, shuffle=False,
#                                  data_preprocess=valid_data_preprocess)

# --------------- Build the network, define the loss function and the optimizer ---------------
# Build the network
# net = resnet()
net = AlexNet(num_classes=cfg.num_classes)
# net = resnet50()
# net = resnet18()
# Rewrite the final layer of the network
# fc_in_features = net.fc.in_features  # input width of the network's final layer
# net.fc = nn.Linear(in_features=fc_in_features, out_features=cfg.num_classes)

# Move the network and the loss function onto the GPU; configure the optimizer
net = net.to(cfg.device)
# net = nn.DataParallel(net, device_ids=[0, 1])
# criterion = nn.BCELoss()
# criterion = nn.BCEWithLogitsLoss().cuda(device=cfg.device)
criterion = nn.CrossEntropyLoss().cuda(device=cfg.device)
# Standard optimizers: stochastic gradient descent and Adam
# optimizer = optim.SGD(params=net.parameters(), lr=cfg.learning_rate,
#                       weight_decay=cfg.weight_decay, momentum=cfg.momentum)
optimizer = optim.Adam(params=net.parameters(), lr=cfg.learning_rate,
                       weight_decay=cfg.weight_decay)
# Optimizer for a linear learning-rate schedule
# optimizer = optim.SGD(params=net.parameters(), lr=cfg.learning_rate,
#                       weight_decay=cfg.weight_decay, momentum=cfg.momentum)

# -------------- Run training --------------
# print('Starting training....')
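# A minimal sketch of the commented-out "rewrite the final layer" step above,
# using torchvision's resnet50. The pretrained flag and the num_classes
# parameter are illustrative assumptions, not this repository's configuration.
import torch.nn as nn
from torchvision.models import resnet50


def build_finetune_net(num_classes):
    net = resnet50(pretrained=True)       # load ImageNet weights
    fc_in_features = net.fc.in_features   # input width of the final layer
    # Replace the 1000-way ImageNet head with a num_classes-way head;
    # the new layer is randomly initialized and trained from scratch.
    net.fc = nn.Linear(in_features=fc_in_features, out_features=num_classes)
    return net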
import argparse
import logging
import os

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from tensorboardX import SummaryWriter

# Project-local modules are assumed importable from the repository:
# AlexNet, AlexNet_cifar, ResNet18_cifar, ResNet50_cifar, WideResNet,
# resnet18, resnet50, cifar, imagenet, getLogger, beautify, colorful,
# train, validate, adjust_learning_rate.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--stage', default='train', type=str)
    parser.add_argument('--dataset', default='imagenet', type=str)
    parser.add_argument('--lr', default=0.0012, type=float)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--gpus', default='0,1,2,3', type=str)
    parser.add_argument('--weight_decay', default=1e-5, type=float)
    parser.add_argument('--max_epoch', default=30, type=int)
    parser.add_argument('--lr_decay_steps', default='15,20,25', type=str)
    parser.add_argument('--exp', default='', type=str)
    parser.add_argument('--list', default='', type=str)
    parser.add_argument('--resume_path', default='', type=str)
    parser.add_argument('--pretrain_path', default='', type=str)
    parser.add_argument('--n_workers', default=32, type=int)
    parser.add_argument('--network', default='resnet50', type=str)
    global args
    args = parser.parse_args()

    # Create the experiment directory tree.
    if not os.path.exists(args.exp):
        os.makedirs(args.exp)
    if not os.path.exists(os.path.join(args.exp, 'runs')):
        os.makedirs(os.path.join(args.exp, 'runs'))
    if not os.path.exists(os.path.join(args.exp, 'models')):
        os.makedirs(os.path.join(args.exp, 'models'))
    if not os.path.exists(os.path.join(args.exp, 'logs')):
        os.makedirs(os.path.join(args.exp, 'logs'))

    # Initialize the logger.
    logger = getLogger(args.exp)

    device_ids = list(map(lambda x: int(x), args.gpus.split(',')))
    device = torch.device('cuda:0')

    train_loader, val_loader = (cifar.get_semi_dataloader(args)
                                if args.dataset.startswith('cifar')
                                else imagenet.get_semi_dataloader(args))

    # Create the model.
    if args.network == 'alexnet':
        network = AlexNet(128)
    elif args.network == 'alexnet_cifar':
        network = AlexNet_cifar(128)
    elif args.network == 'resnet18_cifar':
        network = ResNet18_cifar()
    elif args.network == 'resnet50_cifar':
        network = ResNet50_cifar()
    elif args.network == 'wide_resnet28':
        network = WideResNet(28, args.dataset == 'cifar10' and 10 or 100, 2)
    elif args.network == 'resnet18':
        network = resnet18()
    elif args.network == 'resnet50':
        network = resnet50()
    else:
        raise ValueError('unknown network: {}'.format(args.network))
    network = nn.DataParallel(network, device_ids=device_ids)
    network.to(device)
    classifier = nn.Linear(2048, 1000).to(device)

    # Create the optimizers: one for the backbone and one, with a 50x larger
    # learning rate, for the linear classifier.
    parameters = network.parameters()
    optimizer = torch.optim.SGD(
        parameters,
        lr=args.lr,
        momentum=0.9,
        weight_decay=args.weight_decay,
    )
    cls_optimizer = torch.optim.SGD(
        classifier.parameters(),
        lr=args.lr * 50,
        momentum=0.9,
        weight_decay=args.weight_decay,
    )
    cudnn.benchmark = True

    # Create the TensorBoard writer.
    global writer
    writer = SummaryWriter(comment='SemiSupervised',
                           logdir=os.path.join(args.exp, 'runs'))

    # Create the criterion.
    criterion = nn.CrossEntropyLoss()

    logging.info(beautify(args))
    start_epoch = 0
    if args.pretrain_path != '' and args.pretrain_path != 'none':
        logging.info('loading pretrained file from {}'.format(args.pretrain_path))
        checkpoint = torch.load(args.pretrain_path)
        state_dict = checkpoint['state_dict']
        # Keep only weights that exist in the current network and do not
        # belong to the fully connected head.
        valid_state_dict = {
            k: v for k, v in state_dict.items()
            if k in network.state_dict() and 'fc.' not in k
        }
        for k, v in network.state_dict().items():
            if k not in valid_state_dict:
                logging.info('{}: Random Init'.format(k))
                valid_state_dict[k] = v
        # logging.info(valid_state_dict.keys())
        network.load_state_dict(valid_state_dict)
    else:
        logging.info('Training SemiSupervised Learning From Scratch')

    logging.info('start training')
    best_acc = 0.0
    try:
        for i_epoch in range(start_epoch, args.max_epoch):
            train(i_epoch, network, classifier, criterion, optimizer,
                  cls_optimizer, train_loader, device)

            checkpoint = {
                'epoch': i_epoch + 1,
                'state_dict': network.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(checkpoint,
                       os.path.join(args.exp, 'models', 'checkpoint.pth'))

            adjust_learning_rate(args.lr_decay_steps, optimizer, i_epoch)

            # Validate every other epoch and keep the best checkpoint.
            if i_epoch % 2 == 0:
                acc1, acc5 = validate(i_epoch, network, classifier,
                                      val_loader, device)
                if acc1 >= best_acc:
                    best_acc = acc1
                    torch.save(checkpoint,
                               os.path.join(args.exp, 'models', 'best.pth'))
                writer.add_scalar('acc1', acc1, i_epoch + 1)
                writer.add_scalar('acc5', acc5, i_epoch + 1)
                # Keep milestone snapshots at selected epochs.
                if i_epoch in [30, 60, 120, 160, 200]:
                    torch.save(
                        checkpoint,
                        os.path.join(args.exp, 'models',
                                     '{}.pth'.format(i_epoch + 1)))
                logging.info(
                    colorful('[Epoch: {}] val acc: {:.4f}/{:.4f}'.format(
                        i_epoch, acc1, acc5)))
                logging.info(
                    colorful('[Epoch: {}] best acc: {:.4f}'.format(
                        i_epoch, best_acc)))

            # Log weight histograms (batch-norm parameters excluded).
            with torch.no_grad():
                for name, param in network.named_parameters():
                    if 'bn' not in name:
                        writer.add_histogram(name, param, i_epoch)
            # cluster
    except KeyboardInterrupt:
        logging.info('KeyboardInterrupt at {} Epochs'.format(i_epoch))
        exit()
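# adjust_learning_rate() is called in main() above but is not shown in this
# excerpt. Below is a plausible sketch under the usual convention implied by
# --lr_decay_steps='15,20,25': multiply the learning rate by 0.1 each time a
# listed epoch is reached. This is an assumption, not the repository's
# actual implementation.
def adjust_learning_rate(lr_decay_steps, optimizer, i_epoch):
    # Parse the comma-separated decay epochs, e.g. '15,20,25' -> [15, 20, 25].
    decay_steps = [int(s) for s in lr_decay_steps.split(',')]
    if i_epoch in decay_steps:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.1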