Example #1
import sys
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR

# AlexNet is the example's own model class, defined elsewhere in its project.

def train(data_train, data_val, num_classes, num_epoch, milestones):
    model = AlexNet(num_classes, pretrain=False)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    lr_scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    since = time.time()
    best_acc = 0
    best = 0  # epoch at which the best validation accuracy was reached
    for epoch in range(num_epoch):
        print('Epoch {}/{}'.format(epoch + 1, num_epoch))
        print('-' * 10)


        # Iterate over data.
        running_loss = 0.0
        running_corrects = 0
        model.train()
        with torch.set_grad_enabled(True):
            for i, (inputs, labels) in enumerate(data_train):
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)  # predicted class per sample
                loss = criterion(outputs, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                # accumulate per-batch accuracy; averaged over all batches below
                running_corrects += torch.sum(preds == labels.data) * 1. / inputs.size(0)
                print("\rIteration: {}/{}, Loss: {}.".format(i + 1, len(data_train), loss.item()), end="")

                sys.stdout.flush()

        # len(data_train) is the number of batches, so these are per-batch means
        avg_loss = running_loss / len(data_train)
        t_acc = running_corrects.double() / len(data_train)

        running_loss = 0.0
        running_corrects = 0
        model.eval()
        with torch.set_grad_enabled(False):
            for i, (inputs, labels) in enumerate(data_val):
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                running_loss += loss.item()
                running_corrects += torch.sum(preds == labels.data) * 1. / inputs.size(0)

        val_loss = running_loss / len(data_val)
        val_acc = running_corrects.double() / len(data_val)

        print()
        print('Train Loss: {:.4f} Acc: {:.4f}'.format(avg_loss, t_acc))
        print('Val Loss: {:.4f} Acc: {:.4f}'.format(val_loss, val_acc))
        print('lr rate: {:.6f}'.format(optimizer.param_groups[0]['lr']))
        print()

        if val_acc > best_acc:
            best_acc = val_acc
            best = epoch + 1

        lr_scheduler.step()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Validation Accuracy: {}, Epoch: {}'.format(best_acc, best))

    return model
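
A minimal usage sketch for train() above, assuming an ImageFolder-style dataset on disk; the paths, batch size, transforms, and milestone epochs here are placeholders rather than part of the original example.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# basic preprocessing with the usual ImageNet normalization statistics
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# hypothetical dataset locations; replace with real paths
train_set = datasets.ImageFolder('data/train', transform=preprocess)
val_set = datasets.ImageFolder('data/val', transform=preprocess)

data_train = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
data_val = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)

# drop the learning rate by 10x at epochs 10 and 15 (gamma=0.1 in the scheduler)
model = train(data_train, data_val,
              num_classes=len(train_set.classes),
              num_epoch=20, milestones=[10, 15])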
Example #2
                                  data_preprocess=valid_data_preprocess)
# test_loader = cfg.dataset_loader(root=cfg.cat_dog_test, train=False, shuffle=False,
#                                  data_preprocess=valid_data_preprocess)

# --------------- Build the network, define the loss function and optimizer ---------------
# build the network architecture
# net = resnet()
net = AlexNet(num_classes=cfg.num_classes)
# net = resnet50()
#net = resnet18()
# rewrite the network's final layer
#fc_in_features = net.fc.in_features  # input features of the final layer
#net.fc = nn.Linear(in_features=fc_in_features, out_features=cfg.num_classes)

# move the network and loss function to the GPU; configure the optimizer
net = net.to(cfg.device)
# net = nn.DataParallel(net, device_ids=[0, 1])
# criterion=nn.BCELoss()
#criterion = nn.BCEWithLogitsLoss().cuda(device=cfg.device)
criterion = nn.CrossEntropyLoss().cuda(device=cfg.device)
# standard optimizers: SGD and Adam
#optimizer = optim.SGD(params=net.parameters(), lr=cfg.learning_rate,
#                      weight_decay=cfg.weight_decay, momentum=cfg.momentum)
optimizer = optim.Adam(params=net.parameters(), lr=cfg.learning_rate,
                       weight_decay=cfg.weight_decay)
# SGD variant with a linear learning-rate schedule (kept commented for reference)
#optimizer = optim.SGD(params=net.parameters(), lr=cfg.learning_rate,
#                      weight_decay=cfg.weight_decay, momentum=cfg.momentum)

# -------------- Run training -----------------
# print('Training...')
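
The objects built above (net, criterion, optimizer) combine in the standard PyTorch training step; a minimal sketch, assuming a train_loader that yields (inputs, labels) batches (the loader construction itself is truncated at the top of this example):

for inputs, labels in train_loader:
    inputs, labels = inputs.to(cfg.device), labels.to(cfg.device)
    optimizer.zero_grad()                  # clear gradients from the previous step
    loss = criterion(net(inputs), labels)  # forward pass + loss
    loss.backward()                        # backpropagation
    optimizer.step()                       # weight update
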
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--stage', default='train', type=str)
    parser.add_argument('--dataset', default='imagenet', type=str)
    parser.add_argument('--lr', default=0.0012, type=float)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--gpus', default='0,1,2,3', type=str)
    parser.add_argument('--weight_decay', default=1e-5, type=float)
    parser.add_argument('--max_epoch', default=30, type=int)
    parser.add_argument('--lr_decay_steps', default='15,20,25', type=str)
    parser.add_argument('--exp', default='', type=str)
    parser.add_argument('--list', default='', type=str)
    parser.add_argument('--resume_path', default='', type=str)
    parser.add_argument('--pretrain_path', default='', type=str)
    parser.add_argument('--n_workers', default=32, type=int)

    parser.add_argument('--network', default='resnet50', type=str)

    global args
    args = parser.parse_args()

    # create the experiment directory tree (exist_ok=True also creates args.exp itself)
    for sub_dir in ('runs', 'models', 'logs'):
        os.makedirs(os.path.join(args.exp, sub_dir), exist_ok=True)

    # logger initialize
    logger = getLogger(args.exp)

    device_ids = [int(x) for x in args.gpus.split(',')]
    device = torch.device('cuda:0')

    if args.dataset.startswith('cifar'):
        train_loader, val_loader = cifar.get_semi_dataloader(args)
    else:
        train_loader, val_loader = imagenet.get_semi_dataloader(args)

    # create model
    if args.network == 'alexnet':
        network = AlexNet(128)
    elif args.network == 'alexnet_cifar':
        network = AlexNet_cifar(128)
    elif args.network == 'resnet18_cifar':
        network = ResNet18_cifar()
    elif args.network == 'resnet50_cifar':
        network = ResNet50_cifar()
    elif args.network == 'wide_resnet28':
        network = WideResNet(28, 10 if args.dataset == 'cifar10' else 100, 2)
    elif args.network == 'resnet18':
        network = resnet18()
    elif args.network == 'resnet50':
        network = resnet50()
    else:
        raise ValueError('unknown network: {}'.format(args.network))
    network = nn.DataParallel(network, device_ids=device_ids)
    network.to(device)

    # linear evaluation head: 2048-d (ResNet-50) features -> 1000 classes
    classifier = nn.Linear(2048, 1000).to(device)
    # create optimizer

    parameters = network.parameters()
    optimizer = torch.optim.SGD(
        parameters,
        lr=args.lr,
        momentum=0.9,
        weight_decay=args.weight_decay,
    )

    # the classifier head trains with a 50x larger learning rate than the backbone
    cls_optimizer = torch.optim.SGD(
        classifier.parameters(),
        lr=args.lr * 50,
        momentum=0.9,
        weight_decay=args.weight_decay,
    )

    cudnn.benchmark = True

    # create the TensorBoard summary writer
    global writer
    writer = SummaryWriter(comment='SemiSupervised',
                           logdir=os.path.join(args.exp, 'runs'))

    # create criterion
    criterion = nn.CrossEntropyLoss()

    logging.info(beautify(args))
    start_epoch = 0
    if args.pretrain_path != '' and args.pretrain_path != 'none':
        logging.info('loading pretrained file from {}'.format(
            args.pretrain_path))
        checkpoint = torch.load(args.pretrain_path)
        state_dict = checkpoint['state_dict']
        valid_state_dict = {
            k: v
            for k, v in state_dict.items()
            if k in network.state_dict() and 'fc.' not in k
        }
        for k, v in network.state_dict().items():
            if k not in valid_state_dict:
                logging.info('{}: Random Init'.format(k))
                valid_state_dict[k] = v
        # logging.info(valid_state_dict.keys())
        network.load_state_dict(valid_state_dict)
    else:
        logging.info('Training SemiSupervised Learning From Scratch')

    logging.info('start training')
    best_acc = 0.0
    try:
        for i_epoch in range(start_epoch, args.max_epoch):
            train(i_epoch, network, classifier, criterion, optimizer,
                  cls_optimizer, train_loader, device)

            checkpoint = {
                'epoch': i_epoch + 1,
                'state_dict': network.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(checkpoint,
                       os.path.join(args.exp, 'models', 'checkpoint.pth'))
            adjust_learning_rate(args.lr_decay_steps, optimizer, i_epoch)
            if i_epoch % 2 == 0:
                acc1, acc5 = validate(i_epoch, network, classifier, val_loader,
                                      device)
                if acc1 >= best_acc:
                    best_acc = acc1
                    torch.save(checkpoint,
                               os.path.join(args.exp, 'models', 'best.pth'))
                writer.add_scalar('acc1', acc1, i_epoch + 1)
                writer.add_scalar('acc5', acc5, i_epoch + 1)
                logging.info(
                    colorful('[Epoch: {}] val acc: {:.4f}/{:.4f}'.format(
                        i_epoch, acc1, acc5)))

            if i_epoch in [30, 60, 120, 160, 200]:
                torch.save(
                    checkpoint,
                    os.path.join(args.exp, 'models',
                                 '{}.pth'.format(i_epoch + 1)))

            logging.info(
                colorful('[Epoch: {}] best acc: {:.4f}'.format(
                    i_epoch, best_acc)))

            with torch.no_grad():
                for name, param in network.named_parameters():
                    if 'bn' not in name:
                        writer.add_histogram(name, param, i_epoch)

            # cluster
    except KeyboardInterrupt:
        logging.info('KeyboardInterrupt at {} Epochs'.format(i_epoch))
        exit()
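
adjust_learning_rate() is called in the epoch loop above but not defined in this snippet. Below is a plausible minimal sketch, assuming lr_decay_steps is the comma-separated string from --lr_decay_steps and that the rate is multiplied by 0.1 at each listed epoch; both are assumptions, not the original helper.

def adjust_learning_rate(lr_decay_steps, optimizer, epoch, gamma=0.1):
    # hypothetical reimplementation: multiply the LR by gamma at each decay epoch
    steps = [int(s) for s in lr_decay_steps.split(',')]
    if epoch in steps:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= gamma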