Пример #1
0
def main():
    """Train a single classification network from the command line.

    Parses CLI options, builds the network selected by ``--net``, wraps it
    in a ``Session`` and trains it with Adam for ``--epochs`` epochs.
    After every epoch the train/valid loss and accuracy, the validation
    AUC and the current learning rate are written to TensorBoard, and the
    learning rate is reduced (factor 0.1, patience 10) when validation AUC
    stops improving. Checkpoints are saved at start (``start.pth.tar``),
    whenever validation AUC improves (``best_model.pth.tar``), every 10
    epochs (``epoch{N}.pth.tar``) and after every epoch
    (``latest.pth.tar``).

    Raises:
        NameError: if ``--net`` is not one of the supported architectures.
    """
    parser = argparse.ArgumentParser()
    # The defaults for --epochs, --batch_size and --net are the values the
    # script previously forced *after* parsing (which made those flags
    # no-ops); as real defaults, the command-line flags now take effect.
    parser.add_argument('--epochs', default=150, type=int, help='epoch number')
    parser.add_argument('-b',
                        '--batch_size',
                        default=32,
                        type=int,
                        help='mini-batch size')
    parser.add_argument('--lr',
                        '--learning_rate',
                        default=1e-3,
                        type=float,
                        help='initial learning rate')
    parser.add_argument('-c',
                        '--continue',
                        dest='continue_path',
                        type=str,
                        required=False)
    parser.add_argument('--exp_name',
                        default=config.exp_name,
                        type=str,
                        required=False)
    parser.add_argument('--drop_rate', default=0, type=float, required=False)
    parser.add_argument('--only_fc',
                        action='store_true',
                        help='only train fc layers')
    parser.add_argument('--net',
                        default='densenet169',
                        type=str,
                        required=False)
    # NOTE(review): parsed and saved with the other args, but not used
    # below -- the dataloaders are built without an ``is_local`` option.
    parser.add_argument('--local',
                        action='store_true',
                        help='train local branch')
    args = parser.parse_args()
    print(args)

    config.exp_name = args.exp_name
    config.make_dir()
    save_args(args, config.log_dir)

    # get network
    if args.net == 'resnet50':
        net = resnet50(pretrained=True, drop_rate=args.drop_rate)
    elif args.net == 'resnet101':
        net = resnet101(pretrained=True, drop_rate=args.drop_rate)
    elif args.net == 'densenet121':
        # Replace the classifier with a single Linear(1024 -> 1) + Sigmoid
        # output. Note that --drop_rate is not applied on this branch.
        net = models.densenet121(pretrained=True)
        net.classifier = nn.Sequential(nn.Linear(1024, 1), nn.Sigmoid())
    elif args.net == 'densenet169':
        net = densenet169(pretrained=True, drop_rate=args.drop_rate)
    elif args.net == 'fusenet':
        # Build the fusion model from the 'net' entries of two previously
        # saved checkpoints, then drop the local references to them.
        global_branch = torch.load(GLOBAL_BRANCH_DIR)['net']
        local_branch = torch.load(LOCAL_BRANCH_DIR)['net']
        net = fusenet(global_branch, local_branch)
        del global_branch, local_branch
    else:
        raise NameError('unsupported --net: {!r}'.format(args.net))

    net = net.cuda()
    sess = Session(config, net=net)

    # get dataloaders
    train_loader = get_dataloaders('train',
                                   batch_size=args.batch_size,
                                   num_workers=4,
                                   shuffle=True)
    valid_loader = get_dataloaders('valid',
                                   batch_size=args.batch_size,
                                   shuffle=False)

    # Resume from a checkpoint if one was requested; a missing file is
    # reported instead of being ignored silently.
    if args.continue_path:
        if os.path.exists(args.continue_path):
            sess.load_checkpoint(args.continue_path)
        else:
            print('checkpoint not found, starting from scratch: {}'.format(
                args.continue_path))

    # start session
    clock = sess.clock
    tb_writer = sess.tb_writer
    sess.save_checkpoint('start.pth.tar')

    # set criterion, optimizer and scheduler
    # NOTE(review): originally marked "not used", yet it is passed to
    # train_model/valid_model below -- confirm whether they use it.
    criterion = nn.BCELoss().cuda()

    if args.only_fc:
        # Session may wrap the network (e.g. in DataParallel, which exposes
        # it as ``.module``); unwrap so the lookup works either way.
        model = getattr(sess.net, 'module', sess.net)
        optimizer = optim.Adam(model.classifier.parameters(), args.lr)
    else:
        optimizer = optim.Adam(sess.net.parameters(), args.lr)

    # 'max' mode: the monitored metric is validation AUC (higher is better).
    scheduler = ReduceLROnPlateau(optimizer,
                                  'max',
                                  factor=0.1,
                                  patience=10,
                                  verbose=True)

    # start training
    for _ in range(args.epochs):
        train_out = train_model(train_loader, sess.net, criterion, optimizer,
                                clock.epoch)
        valid_out = valid_model(valid_loader, sess.net, criterion, optimizer,
                                clock.epoch)

        tb_writer.add_scalars('loss', {
            'train': train_out['epoch_loss'],
            'valid': valid_out['epoch_loss']
        }, clock.epoch)

        tb_writer.add_scalars('acc', {
            'train': train_out['epoch_acc'],
            'valid': valid_out['epoch_acc']
        }, clock.epoch)

        tb_writer.add_scalar('auc', valid_out['epoch_auc'], clock.epoch)

        tb_writer.add_scalar('learning_rate', optimizer.param_groups[-1]['lr'],
                             clock.epoch)
        scheduler.step(valid_out['epoch_auc'])

        # ``best_val_acc`` is reused to track the best validation AUC.
        if valid_out['epoch_auc'] > sess.best_val_acc:
            sess.best_val_acc = valid_out['epoch_auc']
            sess.save_checkpoint('best_model.pth.tar')

        if clock.epoch % 10 == 0:
            sess.save_checkpoint('epoch{}.pth.tar'.format(clock.epoch))
        sess.save_checkpoint('latest.pth.tar')

        clock.tock()
Пример #2
0
def main():
    """Train the fusion network (``fusenet``) from the command line.

    The network is rebuilt either from a saved state dict (``--state_dict``,
    followed by ``set_fcweights()``) or from the global/local branch files
    at ``GLOBAL_BRANCH_DIR`` / ``LOCAL_BRANCH_DIR``. Only part of the model
    is optimized: the local branch with ``--local``, otherwise the final
    classifier. After every epoch the train/valid loss and accuracy, the
    validation AUC and the current learning rate are written to
    TensorBoard, and the learning rate is reduced (factor 0.1, patience 10)
    when validation AUC stops improving. Checkpoints are saved at start,
    on every new best validation AUC, every 10 epochs and after every
    epoch (``latest.pth.tar``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', default=50, type=int, help='epoch number')
    parser.add_argument('-b', '--batch_size', default=64, type=int, help='mini-batch size')
    parser.add_argument('--lr', '--learning_rate', default=1e-4, type=float, help='initial learning rate')
    parser.add_argument('-c', '--continue', dest='continue_path', type=str, required=False)
    parser.add_argument('--state_dict', default=None, type=str, required=False,
                        help='state_dict when doing full training ')
    parser.add_argument('--exp_name', default=config.exp_name, type=str, required=False)
    parser.add_argument('--drop_rate', default=0, type=float, required=False)
    parser.add_argument('--local', action='store_true', help='train local branch')
    args = parser.parse_args()
    print(args)

    config.exp_name = args.exp_name
    config.make_dir()
    save_args(args, config.log_dir)

    # get network
    if args.state_dict is not None:
        # Full training: restore a complete fusenet from a saved state dict.
        state_dict = torch.load(args.state_dict)
        net = fusenet()
        net.load_state_dict(state_dict)
        net.set_fcweights()
    else:
        # Assemble fusenet from the separately saved global and local branches.
        global_branch_state = torch.load(GLOBAL_BRANCH_DIR)
        local_branch_state = torch.load(LOCAL_BRANCH_DIR)
        net = fusenet(global_branch_state, local_branch_state)

    net.to(config.device)
    sess = Session(config, net=net)

    # get dataloaders
    train_loader = get_dataloaders('train', batch_size=args.batch_size,
                                   shuffle=True)
    valid_loader = get_dataloaders('valid', batch_size=args.batch_size,
                                   shuffle=False)

    # Resume from a checkpoint if one was requested; a missing file is
    # reported instead of being ignored silently.
    if args.continue_path:
        if os.path.exists(args.continue_path):
            sess.load_checkpoint(args.continue_path)
        else:
            print('checkpoint not found, starting from scratch: {}'.format(
                args.continue_path))

    # start session
    clock = sess.clock
    tb_writer = sess.tb_writer
    sess.save_checkpoint('start.pth.tar')

    # set criterion, optimizer and scheduler
    # Place the loss on the same device as the network (was hard-coded .cuda()).
    criterion = nn.BCELoss().to(config.device)

    # Session may wrap the network (e.g. in DataParallel, which exposes it
    # as ``.module``). The two branches previously disagreed on whether to
    # unwrap, so one of them always failed; unwrap once for both.
    model = getattr(sess.net, 'module', sess.net)
    if args.local:  # train local branch
        optimizer = optim.Adam(model.local_branch.parameters(), args.lr)
    else:  # train final fc layer
        optimizer = optim.Adam(model.classifier.parameters(), args.lr)

    # 'max' mode: the monitored metric is validation AUC (higher is better).
    scheduler = ReduceLROnPlateau(optimizer, 'max', factor=0.1, patience=10, verbose=True)

    # start training
    for _ in range(args.epochs):
        train_out = train_model(train_loader, sess.net,
                                criterion, optimizer, clock.epoch)
        valid_out = valid_model(valid_loader, sess.net,
                                criterion, optimizer, clock.epoch)

        tb_writer.add_scalars('loss', {'train': train_out['epoch_loss'],
                                       'valid': valid_out['epoch_loss']}, clock.epoch)

        tb_writer.add_scalars('acc', {'train': train_out['epoch_acc'],
                                      'valid': valid_out['epoch_acc']}, clock.epoch)

        tb_writer.add_scalar('auc', valid_out['epoch_auc'], clock.epoch)

        tb_writer.add_scalar('learning_rate', optimizer.param_groups[-1]['lr'], clock.epoch)
        scheduler.step(valid_out['epoch_auc'])

        # ``best_val_acc`` is reused to track the best validation AUC.
        if valid_out['epoch_auc'] > sess.best_val_acc:
            sess.best_val_acc = valid_out['epoch_auc']
            sess.save_checkpoint('best_model.pth.tar')

        if clock.epoch % 10 == 0:
            sess.save_checkpoint('epoch{}.pth.tar'.format(clock.epoch))
        sess.save_checkpoint('latest.pth.tar')

        clock.tock()
Пример #3
0
    dataloader = get_dataloaders(args.phase,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 data_dir=args.data_dir)
    total_scores = None  # voting scores
    labels = None
    st_corrects = {st: 0 for st in config.study_type}
    nr_stype = {st: 0 for st in config.study_type}

    for j in range(len(model_list)):
        print('single model ' + str(j), model_list[j])

        if 'fuse' in model_list[j]:
            state_dict = torch.load(model_list[j])['state_dict']
            net = fusenet()
            net.load_state_dict(state_dict)
            net.set_fcweights()
            net = torch.nn.DataParallel(net).cuda()

        else:
            net = torch.load(model_list[j])['net']

        score, labels, nr_stype = get_scores(net, dataloader)
        if j == 0:
            total_scores = {x: score[x] for x in score.keys()}
        else:
            total_scores = {
                x: total_scores[x] + score[x]
                for x in total_scores.keys()
            }