Exemple #1
0
    val_loader = DataLoader(val_dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=workers,
                            collate_fn=val_dataset.collate_fn,
                            pin_memory=True)
    optimizer = torch.optim.SGD(net.parameters(),
                                lr,
                                momentum=0.9,
                                weight_decay=1e-4)
    # optimizer = torch.optim.Adam(net.parameters(), lr, weight_decay=1e-4)

    lrs = LRScheduler(lr,
                      epochs,
                      patience=3,
                      factor=0.1,
                      min_lr=1e-5,
                      early_stop=5,
                      best_loss=best_val_loss)
    for epoch in range(start_epoch, epochs + 1):
        train_metrics, train_time = train(train_loader, net, loss, optimizer,
                                          lr)
        val_metrics, val_time = validate(val_loader, net, loss)

        print_log(epoch, lr, train_metrics, train_time, val_metrics, val_time,
                  save_dir, log_mode)

        val_loss = np.mean(val_metrics[:, 0])
        lr = lrs.update_by_rule(val_loss)
        if val_loss < best_val_loss or epoch % 10 == 0 or lr is None:
            if val_loss < best_val_loss:
Exemple #2
0
def train_model(model, num_epoch):
    """Train ``model`` for ``num_epoch`` epochs and plot the learning-rate curve.

    The backbone parameters are trained at 0.1x the scheduled learning rate,
    while the ``liner1`` / ``liner2`` heads use the full rate (nesterov SGD,
    weight decay 5e-4). Batches with fewer than 32 samples are skipped.

    Reads the module-level ``train_loader``, ``dataset_sizes`` (used here as a
    plain number) and ``LRScheduler``; these are defined elsewhere.

    Args:
        model: wrapper exposing ``model.module.liner1`` and
            ``model.module.liner2`` (presumably ``nn.DataParallel`` -- confirm).
        num_epoch: number of epochs to run.
    """
    since = time.time()
    lr_list = []
    # CrossEntropyLoss keeps no per-batch state; build it once, not per batch.
    criterion = nn.CrossEntropyLoss()
    for epoch in range(num_epoch):
        print('Epoch {}/{}'.format(epoch, num_epoch - 1))
        print('-' * 20)
        running_loss = 0.0
        running_corrects = 0.0

        lr_scheduler = LRScheduler(base_lr=0.05,
                                   step=[50, 80],
                                   factor=0.1,
                                   warmup_epoch=10,
                                   warmup_begin_lr=3e-4)
        lr = lr_scheduler.update(epoch)
        lr_list.append(lr)
        print(lr)

        # The two heads get the full lr; every other parameter gets 0.1x.
        # A set of ids gives O(1) membership tests when filtering.
        head_ids = set(map(id, model.module.liner1.parameters()))
        head_ids.update(map(id, model.module.liner2.parameters()))
        base_params = [p for p in model.module.parameters()
                       if id(p) not in head_ids]
        # Rebuilt each epoch with the new lr (momentum buffers therefore
        # restart every epoch, as in the original code).
        optimizer = optim.SGD([
            {'params': base_params, 'lr': 0.1 * lr},
            {'params': model.module.liner1.parameters(), 'lr': lr},
            {'params': model.module.liner2.parameters(), 'lr': lr},
        ],
                              weight_decay=5e-4,
                              momentum=0.9,
                              nesterov=True)

        for inputs, labels in train_loader:
            now_batch_size, _, _, _ = inputs.shape
            if now_batch_size < 32:  # skip small (e.g. trailing) batches
                continue
            inputs = inputs.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            out = model(inputs)
            loss = criterion(out, labels)
            # Accumulate a Python float: adding the loss tensor itself chained
            # every batch's autograd graph onto running_loss.
            running_loss += loss.item()
            _, preds = torch.max(out.data, 1)
            running_corrects += float(torch.sum(preds == labels.data))
            loss.backward()
            optimizer.step()

        # Computed once per epoch; previously it was recomputed per batch and
        # was undefined (or stale) when every batch of an epoch was skipped.
        epoch_acc = running_corrects / dataset_sizes
        print('Epoch:{}   Loss: {:.4f}  acc: {:.4f} '.format(
            epoch, running_loss, epoch_acc))
        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()

    plt.plot(lr_list)
    plt.show()
    print('Finished training....')
Exemple #3
0
def train_model(model, criterion, triplet, num_epochs):
    """Train a 3-head classifier with an auxiliary triplet loss.

    ``model(inputs)`` must return three classifier logit tensors followed by
    six embedding tensors. The loss is the mean of ``criterion`` over the three
    heads plus the mean of ``triplet(q, labels)[0]`` over the six embeddings.
    Per-epoch precision/loss are appended to ``./model/<name>/<name>.txt``.

    Relies on module-level ``dataloaders``, ``dataset_sizes``, ``batchsize``,
    ``use_gpu``, ``name``, ``y_loss``, ``y_err``, ``LRScheduler``,
    ``save_network`` and ``draw_curve`` (defined elsewhere).

    Args:
        model: network to train (mutated in place).
        criterion: classification loss applied to each head.
        triplet: callable whose first return value is the triplet loss.
        num_epochs: number of epochs to run.

    Returns:
        ``model`` after loading its last weights and saving them as 'last'.
    """
    since = time.time()
    # Seeded up front so the final reload is defined even when num_epochs <= 0
    # (previously an UnboundLocalError). Note: state_dict() returns references
    # to the live tensors, not a snapshot.
    last_model_wts = model.state_dict()

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # update learning rate
        lr_scheduler = LRScheduler(base_lr=3e-2,
                                   step=[60, 130],
                                   factor=0.1,
                                   warmup_epoch=10,
                                   warmup_begin_lr=3e-4)
        lr = lr_scheduler.update(epoch)
        # Built with lr directly, so no extra param_group update is needed.
        optimizer = optim.SGD(model.parameters(),
                              lr=lr,
                              weight_decay=5e-4,
                              momentum=0.9,
                              nesterov=True)
        print(lr)
        # Each epoch has a training (and currently disabled validation) phase.
        for phase in ['train']:
            # True -> training mode, False -> evaluation mode.
            model.train(phase == 'train')

            running_loss = 0.0
            running_corrects = 0.0
            for inputs, labels in dataloaders[phase]:
                now_batch_size, _, _, _ = inputs.shape
                if now_batch_size < batchsize:  # skip the last batch
                    continue
                if use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()
                else:
                    inputs, labels = Variable(inputs), Variable(labels)
                optimizer.zero_grad()

                # forward: 3 classifier outputs followed by 6 embeddings
                outputs1, outputs2, outputs3, q1, q2, q3, q4, q5, q6 = model(
                    inputs)
                logits = (outputs1, outputs2, outputs3)
                embeddings = (q1, q2, q3, q4, q5, q6)
                preds = [torch.max(out.data, 1)[1] for out in logits]

                # Average the classification and triplet terms separately.
                cls_loss = sum(criterion(out, labels)
                               for out in logits) / len(logits)
                tri_loss = sum(triplet(q, labels)[0]
                               for q in embeddings) / len(embeddings)
                loss = cls_loss + tri_loss
                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # .item() exists from PyTorch 0.4 on. The previous check,
                # int(version[2]) > 3, read a single character of the version
                # string and sent e.g. 1.10+/2.x to loss.data[0], which fails
                # on 0-dim tensors.
                if hasattr(loss, 'item'):
                    batch_loss = loss.item()
                else:  # PyTorch 0.3.x
                    batch_loss = loss.data[0]
                running_loss += batch_loss * now_batch_size
                # Correct predictions averaged over the three heads.
                running_corrects += sum(
                    float(torch.sum(p == labels.data)) for p in preds) / len(preds)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            # Record each epoch's precision and loss in the log file.
            with open('./model/%s/%s.txt' % (name, name), 'a') as acc_file:
                acc_file.write('Epoch: %2d, Precision: %.8f, Loss: %.8f\n' %
                               (epoch, epoch_acc, epoch_loss))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss,
                                                       epoch_acc))

            y_loss[phase].append(epoch_loss)
            y_err[phase].append(1.0 - epoch_acc)
            if phase == 'train':
                last_model_wts = model.state_dict()
                # Before epoch 150: save every 10th epoch; afterwards: every epoch.
                if epoch < 150:
                    if epoch % 10 == 9:
                        save_network(model, epoch)
                    draw_curve(epoch)
                else:
                    save_network(model, epoch)
                    draw_curve(epoch)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    # load the last model weights
    model.load_state_dict(last_model_wts)
    save_network(model, 'last')
    return model
Exemple #4
0
def main_worker(gpu, args):
    """Build, optionally evaluate, and train a classifier on one GPU.

    Builds the backbone selected by ``args.arch``. It is optionally seeded with
    weights from the local ``pretrain/`` directory and moved to ``args.gpu``.
    Weights may be restored from ``args.resume``. The function then either
    runs a single evaluation (``args.evaluate``) or trains for epochs
    ``args.start_epoch .. args.epochs - 1``. During training it validates
    every 5th epoch and saves the best checkpoints under ``args.model_dir``.
    It updates the module-level ``best_acc1``, ``best_auc`` and ``best_accdr``.

    Args:
        gpu: CUDA device index, or ``None`` to use the current CUDA device.
        args: parsed command-line namespace.

    Raises:
        ValueError: if ``args.arch`` names an unknown backbone.
        SystemExit: status 1 if the resume checkpoint or the dataset is
            unknown/missing; status 0 after a multitask evaluation.
    """
    global best_acc1
    global best_auc
    global minimum_loss
    global count
    global best_accdr
    import sys  # function-local: used for explicit exit statuses below

    args.gpu = gpu

    # exist_ok avoids the exists()/makedirs() race.
    os.makedirs(args.model_dir, exist_ok=True)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.arch == "vgg11":
        from models.vgg import vgg11
        model = vgg11(num_classes=args.num_class, crossCBAM=args.crossCBAM)
    elif args.arch == "resnet50":
        from models.resnet50 import resnet50
        model = resnet50(num_classes=args.num_class, multitask=args.multitask, liu=args.liu,
                         chen=args.chen, CAN_TS=args.CAN_TS, crossCBAM=args.crossCBAM,
                         crosspatialCBAM=args.crosspatialCBAM, choice=args.choice)
    elif args.arch == "resnet34":
        from models.resnet50 import resnet34
        model = resnet34(num_classes=args.num_class, multitask=args.multitask, liu=args.liu,
                         chen=args.chen, CAN_TS=args.CAN_TS, crossCBAM=args.crossCBAM,
                         crosspatialCBAM=args.crosspatialCBAM)
    elif args.arch == "resnet18":
        from models.resnet50 import resnet18
        model = resnet18(num_classes=args.num_class, multitask=args.multitask, liu=args.liu,
                         chen=args.chen, flagCBAM=False, crossCBAM=args.crossCBAM)
    elif args.arch == "densenet161":
        from models.densenet import densenet161
        model = densenet161(num_classes=args.num_class, multitask=args.multitask, cosface=False,
                            liu=args.liu, chen=args.chen, crossCBAM=args.crossCBAM)
    elif args.arch == "wired":
        from models.wirednetwork import CNN
        model = CNN(args, num_classes=args.num_class)
    else:
        # Previously this only printed and then failed later with an
        # UnboundLocalError on `model`; fail fast with a clear message.
        raise ValueError("no backbone model: unknown arch '{}'".format(args.arch))

    if args.pretrained:
        print("==> Load pretrained model")
        model_dict = model.state_dict()
        pretrain_path = {"resnet50": "pretrain/resnet50-19c8e357.pth",
                         "resnet34": "pretrain/resnet34-333f7ec4.pth",
                         "resnet18": "pretrain/resnet18-5c106cde.pth",
                         "densenet161": "pretrain/densenet161-8d451a50.pth",
                         "vgg11": "pretrain/vgg11-bbd30ac9.pth",
                         "densenet121": "pretrain/densenet121-a639ec97.pth"}[args.arch]
        pretrained_dict = torch.load(pretrain_path)
        # Keep only keys the model has, and never copy the classifier layer
        # (presumably because its shape depends on args.num_class).
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        pretrained_dict.pop('classifier.weight', None)
        pretrained_dict.pop('classifier.bias', None)
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    # torch.cuda.set_device(None) raises, so only pin a device when one was
    # given; model.cuda(None) then uses the current CUDA device.
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
    model = model.cuda(args.gpu)

    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    if args.adam:
        optimizer = torch.optim.Adam(model.parameters(), args.base_lr, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), args.base_lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE(review): hard-coded remap of tensors saved on cuda:4 to
            # cuda:0; checkpoints from other devices are not remapped.
            checkpoint = torch.load(args.resume, map_location={'cuda:4': 'cuda:0'})

            if not args.evaluate:
                # Training: load only the weights whose keys exist in the model.
                print("load partial weights")
                model_dict = model.state_dict()
                pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in model_dict}
                model_dict.update(pretrained_dict)
                model.load_state_dict(model_dict)
            else:
                print("load whole weights")
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])

            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            # A missing checkpoint is an error: exit non-zero (was exit(0)).
            sys.exit(1)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    size = 224

    tra = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    tra_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        normalize])

    if args.dataset == 'amd':
        from datasets.amd_dataset import traindataset
    elif args.dataset == 'pm':
        from datasets.pm_dataset import traindataset
    elif args.dataset == "drdme":
        from datasets.drdme_dataset import traindataset
    elif args.dataset == "missidor":
        from datasets.missidor import traindataset
    elif args.dataset == "kaggle":
        from datasets.kaggle import traindataset
    else:
        print("no dataset")
        # Unknown dataset is an error: exit non-zero (was exit(0)).
        sys.exit(1)

    val_dataset = traindataset(root=args.data, mode='val',
                               transform=tra_test, num_class=args.num_class,
                               multitask=args.multitask, args=args)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        eval_start = time.time()
        # Checkpoint directory including the trailing "/" ("" if none).
        # Slicing at the last "/" avoids str.replace stripping every
        # occurrence of the file name (e.g. "x/model/model" -> "x//").
        savedir = args.resume[:args.resume.rfind("/") + 1]
        if not args.multitask:
            acc, auc, precision_dr, recall_dr, f1score_dr = validate(val_loader, model, args)
            result_list = [acc, auc, precision_dr, recall_dr, f1score_dr]
            print("acc, auc, precision, recall, f1", acc, auc, precision_dr, recall_dr, f1score_dr)

            save_result_txt(savedir, result_list)
            print("time", time.time() - eval_start)
            return
        else:
            acc_dr, acc_dme, acc_joint, other_results, se, sp = validate(val_loader, model, args)
            print("acc_dr, acc_dme, acc_joint", acc_dr, acc_dme, acc_joint)
            # NOTE(review): this early exit (kept from the original) makes the
            # reporting/saving below unreachable; it looks like a debugging
            # leftover -- confirm before removing it.
            sys.exit(0)
            print("auc_dr, auc_dme, precision_dr, precision_dme, recall_dr, recall_dme, f1score_dr, f1score_dme",
                  other_results)
            print("se, sp", se, sp)
            result_list = [acc_dr, acc_dme, acc_joint]
            result_list += other_results
            result_list += [se, sp]
            save_result_txt(savedir, result_list)

            print("time", time.time() - eval_start)
            return

    train_dataset = traindataset(root=args.data, mode='train', transform=tra, num_class=args.num_class,
                                 multitask=args.multitask, args=args)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn)

    writer = SummaryWriter()
    writer.add_text('Text', str(args))
    from lr_scheduler import LRScheduler
    lr_scheduler = LRScheduler(optimizer, len(train_loader), args)

    for epoch in range(args.start_epoch, args.epochs):
        is_best = False
        is_best_auc = False
        is_best_acc = False
        # train for one epoch
        loss_train = train(train_loader, model, criterion, lr_scheduler, writer, epoch, optimizer, args)
        writer.add_scalar('Train loss', loss_train, epoch)

        # evaluate on validation set every 5th epoch
        if epoch % 5 == 0:
            if args.dataset == "kaggle":
                acc_dr, auc_dr = validate(val_loader, model, args)
                writer.add_scalar("Val acc_dr", acc_dr, epoch)
                writer.add_scalar("Val auc_dr", auc_dr, epoch)
                is_best = acc_dr >= best_acc1
                best_acc1 = max(acc_dr, best_acc1)
            elif not args.multitask:
                acc, auc, precision, recall, f1 = validate(val_loader, model, args)
                writer.add_scalar("Val acc_dr", acc, epoch)
                writer.add_scalar("Val auc_dr", auc, epoch)
                is_best = auc >= best_acc1
                best_acc1 = max(auc, best_acc1)
            else:
                acc_dr, acc_dme, joint_acc, other_results, se, sp, losses = validate(val_loader, model, args, criterion)
                writer.add_scalar("Val acc_dr", acc_dr, epoch)
                writer.add_scalar("Val acc_dme", acc_dme, epoch)
                writer.add_scalar("Val acc_joint", joint_acc, epoch)
                writer.add_scalar("Val auc_dr", other_results[0], epoch)
                writer.add_scalar("Val auc_dme", other_results[1], epoch)
                writer.add_scalar("val loss", losses, epoch)
                is_best = joint_acc >= best_acc1
                best_acc1 = max(joint_acc, best_acc1)

                is_best_auc = other_results[0] >= best_auc
                best_auc = max(other_results[0], best_auc)

                is_best_acc = acc_dr >= best_accdr
                best_accdr = max(acc_dr, best_accdr)

        if not args.invalid:
            if is_best:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, filename="model_converge.pth.tar", save_dir=args.model_dir)

            if is_best_auc:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_auc,
                    'optimizer': optimizer.state_dict(),
                }, False, filename="converge_auc.pth.tar", save_dir=args.model_dir)

            if is_best_acc:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_accdr,
                    'optimizer': optimizer.state_dict(),
                }, False, filename="converge_acc.pth.tar", save_dir=args.model_dir)

    # Flush and release the TensorBoard event file.
    writer.close()
Exemple #5
0
def train_model(model, criterion, triplet, num_epochs):
    """Train a multi-head classifier with an auxiliary triplet loss.

    ``model(inputs)`` must return one or more classifier logit tensors followed
    by exactly six embedding tensors (the original model returns 21 + 6).
    The loss is the mean of ``criterion`` over the classifier heads plus the
    mean of ``triplet(q, labels)[0]`` over the six embeddings.

    Relies on module-level ``opt``, ``dataloaders``, ``dataset_sizes``,
    ``use_gpu``, ``y_loss``, ``y_err``, ``LRScheduler``, ``save_network`` and
    ``draw_curve`` (defined elsewhere).

    Args:
        model: network to train (mutated in place).
        criterion: classification loss applied to each head.
        triplet: callable whose first return value is the triplet loss.
        num_epochs: number of epochs to run.

    Returns:
        ``model`` after loading its last weights and saving them as 'last'.

    Raises:
        ValueError: if the model returns no classifier outputs before the
            six embeddings.
    """
    num_embeddings = 6  # trailing model outputs consumed by the triplet loss
    since = time.time()
    # Seeded up front so the final reload is defined even when num_epochs <= 0
    # (previously an UnboundLocalError). state_dict() returns live references.
    last_model_wts = model.state_dict()
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # update learning rate
        lr_scheduler = LRScheduler(base_lr=opt.base_lr,
                                   step=[60, 90, 120],
                                   factor=opt.factor,
                                   warmup_epoch=opt.warm_epoch,
                                   warmup_begin_lr=opt.warmup_begin_lr)
        lr = lr_scheduler.update(epoch)
        # Built with lr directly, so no extra param_group update is needed.
        optimizer = optim.SGD(model.parameters(),
                              lr=lr,
                              weight_decay=5e-4,
                              momentum=0.9,
                              nesterov=True)
        print(lr)
        # Each epoch has a training (and currently disabled validation) phase.
        for phase in ['train']:
            # True -> training mode, False -> evaluation mode.
            model.train(phase == 'train')

            running_loss = 0.0
            running_corrects = 0.0
            for inputs, labels in dataloaders[phase]:
                now_batch_size, _, _, _ = inputs.shape
                if now_batch_size < opt.batchsize:  # skip the last batch
                    continue
                if use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()
                else:
                    inputs, labels = Variable(inputs), Variable(labels)
                optimizer.zero_grad()

                # forward: classifier logits followed by the embeddings
                outputs = model(inputs)
                if len(outputs) <= num_embeddings:
                    raise ValueError(
                        'expected classifier outputs followed by {} embeddings, '
                        'got {} outputs'.format(num_embeddings, len(outputs)))
                logits = outputs[:-num_embeddings]
                embeddings = outputs[-num_embeddings:]
                preds = [torch.max(out.data, 1)[1] for out in logits]

                # Average the classification and triplet terms separately.
                cls_loss = sum(criterion(out, labels)
                               for out in logits) / len(logits)
                tri_loss = sum(triplet(q, labels)[0]
                               for q in embeddings) / len(embeddings)
                loss = cls_loss + tri_loss
                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                running_loss += loss.item() * now_batch_size
                # Correct predictions averaged over all classifier heads.
                running_corrects += sum(
                    float(torch.sum(p == labels.data)) for p in preds) / len(preds)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss,
                                                       epoch_acc))
            time_elapsed = time.time() - since
            print('Training times in {:.0f}m {:.0f}s'.format(
                time_elapsed // 60, time_elapsed % 60))

            y_loss[phase].append(epoch_loss)
            y_err[phase].append(1.0 - epoch_acc)
            if phase == 'train':
                last_model_wts = model.state_dict()
                # Before epoch 150: save every 10th epoch; afterwards: every epoch.
                if epoch < 150:
                    if epoch % 10 == 9:
                        save_network(model, epoch)
                    draw_curve(epoch)
                else:
                    save_network(model, epoch)
                    draw_curve(epoch)

        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # load the last model weights
    model.load_state_dict(last_model_wts)
    save_network(model, 'last')
    return model
Exemple #6
0
    lambda x: torch.from_numpy(x),
    torchvision.transforms.RandomVerticalFlip()
])

# --- data: split file lists and serve shuffled training mini-batches -------
train_names, val_names, train_mask, val_mask = load_data()
train_data = dataset(train_names, train_mask, transform)
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)

# --- network, loss, optimizer and per-iteration 'poly' lr schedule ---------
model = get_fast_scnn()
model.to(device)
criterion = MixSoftmaxCrossEntropyOHEMLoss(
    aux=aux,
    aux_weight=0.4,
    ignore_index=-1,
).to(device)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
)
lr_scheduler = LRScheduler(
    mode='poly',
    base_lr=lr,
    nepochs=epochs,
    iters_per_epoch=len(train_loader),
    power=0.9,
)
        
def checkpoint(model, epoch, directory='./'):
    """Save ``model.state_dict()`` to ``<directory>/fscnn_<epoch>.pth``.

    Args:
        model: module whose state dict is saved.
        epoch: value formatted into the file name.
        directory: destination directory; defaults to the current directory,
            which was previously hard-coded.

    Returns:
        The path the weights were written to.
    """
    filename = 'fscnn_{}.pth'.format(epoch)
    save_path = os.path.join(directory, filename)
    torch.save(model.state_dict(), save_path)
    return save_path

iterations = 0
start_time = time.time()
for epoch in range(epochs):
    model.train()
        
    for image, targets in train_loader:
        cur_lr = lr_scheduler(iterations)
        for param_group in optimizer.param_groups:
def _parse_dx_labels(headers, classes, equivalent_pairs=()):
    """Build a 0/1 multi-label matrix from the ``#Dx:`` lines of headers.

    Args:
        headers: one sequence of text lines per recording.
        classes: ordered class codes; column ``j`` corresponds to ``classes[j]``.
        equivalent_pairs: ``(code_a, code_b)`` pairs whose columns are merged,
            i.e. a positive for either code sets both columns.

    Returns:
        ``int`` ndarray of shape ``(len(headers), len(classes))``. A line
        starting with ``#Dx:`` supplies comma-separated codes (if several such
        lines exist, the last one wins); unknown codes are ignored, and a
        header without a ``#Dx:`` line gives an all-zero row.
    """
    num_classes = len(classes)
    column_of = {}
    for j, code in enumerate(classes):
        column_of.setdefault(code, j)  # first occurrence, like list.index
    rows = []
    for header in headers:
        # Fresh row per record: previously a header without '#Dx:' reused the
        # previous record's labels (or raised NameError for the first one).
        row = np.zeros(num_classes)
        for line in header:
            if not line.startswith('#Dx:'):
                continue
            row = np.zeros(num_classes)
            fields = line.strip().split(' ')
            codes = fields[1] if len(fields) > 1 else ''  # tolerate bare '#Dx:'
            for code in codes.split(','):
                j = column_of.get(code.rstrip())
                if j is not None:
                    row[j] = 1
        rows.append(row)
    labels = np.asarray(rows, dtype=int).reshape(-1, num_classes)
    for code_a, code_b in equivalent_pairs:
        a, b = column_of[code_a], column_of[code_b]
        merged = labels[:, a] | labels[:, b]
        labels[:, a] = merged
        labels[:, b] = merged
    return labels


def train_12ECG_classifier(input_directory, output_directory):
    """Load challenge recordings, train a ResNet34 and save it.

    Every non-hidden ``*hea`` file in ``input_directory`` is loaded with
    ``load_challenge_data``. Diagnoses from the headers are mapped onto the 27
    scored classes, and the model is trained via the module-level ``train``.
    The trained network is pickled to ``<output_directory>/net1.pkl``.

    Relies on module-level ``load_challenge_data``, ``ResNet34``, ``device``,
    ``ECGDataset``, ``LRScheduler``, ``Config`` and ``train``.
    """
    # Load data.
    print('Loading data...')
    # Sorted so the record order does not depend on directory listing order.
    header_files = []
    for f in sorted(os.listdir(input_directory)):
        g = os.path.join(input_directory, f)
        if not f.lower().startswith('.') and f.lower().endswith('hea') and os.path.isfile(g):
            header_files.append(g)

    classes = sorted(['270492004', '164889003', '164890007', '426627000', '713427006', '713426002', '445118002', '39732003',
                      '164909002', '251146004', '698252002', '10370003', '284470004', '427172004', '164947007', '111975006',
                      '164917005', '47665007', '59118001', '427393009', '426177001', '426783006', '427084000', '63593006',
                      '164934002', '59931005', '17338001'])
    num_classes = len(classes)

    recordings = []
    headers = []
    for header_file in header_files:
        recording, header = load_challenge_data(header_file)
        recordings.append(recording)
        headers.append(header)

    # Train model.
    print('Training model...')

    # Codes treated as equivalent: their label columns are merged.
    equivalent_pairs = [('713427006', '59118001'),
                        ('284470004', '63593006'),
                        ('427172004', '17338001')]
    labels = _parse_dx_labels(headers, classes, equivalent_pairs)

    # Train the classifier
    model = ResNet34(num_classes=num_classes).to(device)
    train_dataset = ECGDataset(recordings, labels, headers, train=True)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=False)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    niters = len(train_loader)
    lr_scheduler = LRScheduler(optimizer, niters, Config)
    net1 = train(train_loader, model, optimizer, lr_scheduler, 20)

    # Save model.
    print('Saving model...')

    torch.save(net1, os.path.join(output_directory, 'net1.pkl'))
Exemple #8
0
                              collate_fn=train_dataset.collate_fn,
                              pin_memory=True)
    val_dataset = DataGenerator(config, val_data, phase='val')
    val_loader = DataLoader(val_dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=workers,
                            collate_fn=val_dataset.collate_fn,
                            pin_memory=True)
    optimizer = torch.optim.SGD(net.parameters(),
                                lr,
                                momentum=0.9,
                                weight_decay=1e-4)
    lrs = LRScheduler(lr,
                      patience=3,
                      factor=0.1,
                      min_lr=0.01 * lr,
                      best_loss=best_val_loss)
    for epoch in range(start_epoch, epochs + 1):
        train_metrics, train_time = train(train_loader, net, loss, optimizer,
                                          lr)
        with torch.no_grad():
            val_metrics, val_time = validate(val_loader, net, loss)

        print_log(epoch,
                  lr,
                  train_metrics,
                  train_time,
                  val_metrics,
                  val_time,
                  save_dir=save_dir,