Example #1 (score: 0)
def train_network(net, fold=0, model_ckpt=None):
    """Train ``net`` on one cross-validation fold and return the best validation IoU.

    Uses SGD with patience-based LR halving, cyclic restarts via a cosine
    rampdown, and (optionally) stochastic weight averaging (SWA) over the
    per-cycle snapshots.  Ctrl-C aborts the epoch loop but still runs the
    SWA-averaging / checkpoint / CSV-logging tail below the ``try``.

    Args:
        net: the model to train (moved to the right device by the caller).
        fold: cross-validation fold index, used in loader setup and file names.
        model_ckpt: path where the best-IoU weights are saved.

    Returns:
        float: best validation IoU observed during training.
    """
    # Bookkeeping is initialized BEFORE the try so the post-loop code cannot
    # hit a NameError if training is interrupted very early (original defined
    # these inside the try).
    train_losses = []
    valid_losses = []
    valid_ious = []
    best_val_metric = 1000.0
    best_val_iou = 0.0
    cycle = 0

    # train the network, allow for keyboard interrupt
    try:
        # define optimizer
        optimizer = optim.SGD(net.parameters(), lr=args.lr_max, momentum=0.9, weight_decay=args.l2)
        # get the loaders
        train_loader, valid_loader = get_data_loaders(imsize=args.imsize,
                                                      batch_size=args.batch_size,
                                                      num_folds=args.num_folds,
                                                      fold=fold)

        # training flags
        use_lovasz = False   # switched on after the first restart cycle
        freeze_bn = False
        save_imgs = False

        lr_patience = 0      # epochs since improvement -> triggers LR halving
        valid_patience = 0   # epochs since improvement -> triggers a restart
        t_ = 0               # epoch counter driving the cosine rampdown

        print('args.swa', args.swa)
        print('args.lr_rampdown', args.lr_rampdown)
        print('args.use_lovasz', args.use_lovasz)

        print('Training ...')
        for e in range(args.epochs):
            print('\n' + 'Epoch {}/{}'.format(e, args.epochs))

            start = time.time()

            # All param groups share a single LR here, so report it once
            # (original looped over the groups printing group[0] each time).
            print('Learning rate set to {:.4}'.format(optimizer.param_groups[0]['lr']))

            t_l = train(net, optimizer, train_loader, freeze_bn, use_lovasz)
            v_l, viou = valid(net, optimizer, valid_loader, use_lovasz, save_imgs, fold)

            # save the model on best validation IoU
            if viou >= best_val_iou:
                print('new best_val_iou', viou)
                net.eval()
                torch.save(net.state_dict(), model_ckpt)
                best_val_metric = v_l
                best_val_iou = viou
                valid_patience = 0
                lr_patience = 0
            else:
                print('patience', valid_patience)
                valid_patience += 1
                lr_patience += 1

            # no improvement for 15 epochs: snapshot this cycle and restart
            if valid_patience > 15:
                cycle += 1

                checkpath = '../model_weights/{}_{}_cycle-{}_fold-{}.pth'.format(args.model_name, args.exp_name,cycle,fold)
                print('save model', checkpath)
                torch.save(net.state_dict(),checkpath)

                print('cycle', cycle, 'num_cycles', args.num_cycles)
                if cycle >= args.num_cycles:
                    print('all over')
                    break

                # warm restart: reset every group's LR along the cosine curve
                print('rampdown')
                for params in optimizer.param_groups:
                    params['lr'] = (args.lr_min + 0.5 * (args.lr_max - args.lr_min) *
                                    (1 + np.cos(np.pi * t_ / args.lr_rampdown)))

                print('Learning rate set to {:.4}'.format(optimizer.param_groups[0]['lr']))
            elif lr_patience > 5:
                # plateau for 5 epochs: halve the learning rate
                print('Reducing learning rate by {}'.format(0.5))
                for params in optimizer.param_groups:
                    params['lr'] *= 0.5
                lr_patience = 0

            # after the first restart, train with the lovasz loss
            if cycle >= 1:
                print('switching to lovasz')
                use_lovasz = True

            train_losses.append(t_l)
            valid_losses.append(v_l)
            valid_ious.append(viou)

            t_ += 1
            print('Time: {}'.format(time.time()-start))

    except KeyboardInterrupt:
        pass

    if args.swa:
        # Running average of the per-cycle snapshots (stochastic weight
        # averaging): after iteration i the weights are the mean of
        # snapshots 0..i.
        for i in range(cycle):
            if i == 0:
                net.load_state_dict(torch.load('../swa/cycle_{}.pth'.format(i),
                                               map_location=lambda storage, loc: storage))
            else:
                alpha = 1. / (i + 1.)
                prev = ResUNet()
                prev.load_state_dict(torch.load('../swa/cycle_{}.pth'.format(i),
                                                map_location=lambda storage, loc: storage))
                # average weights
                for param_c, param_p in zip(net.parameters(), prev.parameters()):
                    param_c.data *= (1.0 - alpha)
                    param_c.data += param_p.data.to(device) * alpha

        # recompute batch-norm running statistics for the averaged weights
        bn_update(train_loader, net, args.gpu)

    net.eval()
    torch.save(net.state_dict(), '../model_weights/swa_{}_{}_fold-{}.pth'.format(args.model_name,args.exp_name, fold))

    import pandas as pd

    out_dict = {'train_losses': train_losses,
                'valid_losses': valid_losses,
                'valid_ious': valid_ious}

    out_log = pd.DataFrame(out_dict)
    out_log.to_csv('../logs/resunet_fold-{}.csv'.format(fold), index=False)

    return best_val_iou
Example #2 (score: 0)
def train_network(net, fold=0, model_ckpt=None):
    """Train ``net`` on one cross-validation fold and return the best validation IoU.

    Uses SGD with an LR warm-up over the first ``args.lr_rampup`` epochs,
    optional cosine annealing over ``args.lr_rampdown`` epochs, per-cycle
    weight snapshots, and (optionally) stochastic weight averaging (SWA)
    over those snapshots.  Ctrl-C aborts the epoch loop but still runs the
    SWA-averaging / checkpoint / CSV-logging tail below the ``try``.

    Args:
        net: the model to train (moved to the right device by the caller).
        fold: cross-validation fold index, used in loader setup and file names.
        model_ckpt: path where the best-IoU weights are saved.

    Returns:
        float: best validation IoU observed during training.
    """
    # Bookkeeping is initialized BEFORE the try so the post-loop code cannot
    # hit a NameError if training is interrupted very early (original defined
    # these inside the try).
    train_losses = []
    valid_losses = []
    valid_ious = []
    best_val_metric = 1000.0
    best_val_iou = 0.0
    cycle = 0

    # train the network, allow for keyboard interrupt
    try:
        # define optimizer
        optimizer = optim.SGD(net.parameters(),
                              lr=args.lr_max,
                              momentum=0.9,
                              weight_decay=args.l2)
        # get the loaders
        train_loader, valid_loader = get_data_loaders(
            imsize=args.imsize,
            batch_size=args.batch_size,
            num_folds=args.num_folds,
            fold=fold)

        # training flags
        use_lovasz = False   # switched on after the first cycle (if enabled)
        freeze_bn = False
        save_imgs = False

        valid_patience = 0   # epochs since the last validation improvement
        t_ = 0               # epochs since the last restart (cosine clock)

        print('Training ...')
        for e in range(args.epochs):
            print('\n' + 'Epoch {}/{}'.format(e, args.epochs))

            start = time.time()

            # end of an LR period: snapshot weights and start a new cycle
            if t_ >= args.lr_rampdown:
                # if we are using swa save off the current weights before updating
                if args.swa:
                    torch.save(net.state_dict(),
                               '../swa/cycle_{}.pth'.format(cycle))
                # reset the cosine clock
                t_ = 0
                cycle += 1
                torch.save(
                    net.state_dict(),
                    '../model_weights/{}_{}_cycle-{}_fold-{}.pth'.format(
                        args.model_name, args.exp_name, cycle, fold))
                save_imgs = True
            else:
                save_imgs = False

            # schedule the LR for every param group: warm-up first, then
            # (optionally) cosine annealing
            for params in optimizer.param_groups:
                if args.cos_anneal and e > args.lr_rampup:
                    params['lr'] = (
                        args.lr_min + 0.5 * (args.lr_max - args.lr_min) *
                        (1 + np.cos(np.pi * t_ / args.lr_rampdown)))
                elif e < args.lr_rampup:
                    params['lr'] = args.lr_max * (min(t_ + 1, args.lr_rampup) /
                                                  args.lr_rampup)

            # all groups share one LR, so report it once (original printed
            # group[0] once per group)
            print('Learning rate set to {:.4}'.format(
                optimizer.param_groups[0]['lr']))

            t_l = train(net, optimizer, train_loader, freeze_bn, use_lovasz)
            v_l, viou = valid(net, optimizer, valid_loader, use_lovasz,
                              save_imgs, fold)

            # save the model on best validation IoU
            if viou > best_val_iou:
                net.eval()
                torch.save(net.state_dict(), model_ckpt)
                best_val_metric = v_l
                best_val_iou = viou
                valid_patience = 0
            else:
                valid_patience += 1

            # stop once the requested number of cycles has completed
            if cycle >= args.num_cycles:
                break

            # after the first cycle: optionally switch loss and decay the LR
            if cycle >= 1:
                if args.use_lovasz:
                    print('switching to lovasz')
                    use_lovasz = True

                if not args.cos_anneal:
                    print('Reducing learning rate by {}'.format(args.lr_scale))
                    for params in optimizer.param_groups:
                        params['lr'] *= args.lr_scale

            train_losses.append(t_l)
            valid_losses.append(v_l)
            valid_ious.append(viou)

            t_ += 1
            print('Time: {}'.format(time.time() - start))

    except KeyboardInterrupt:
        pass

    if args.swa:
        # Running average of the per-cycle snapshots (stochastic weight
        # averaging): after iteration i the weights are the mean of
        # snapshots 0..i.
        for i in range(cycle):
            if i == 0:
                net.load_state_dict(
                    torch.load('../swa/cycle_{}.pth'.format(i),
                               map_location=lambda storage, loc: storage))
            else:
                alpha = 1. / (i + 1.)
                prev = ResUNet()
                prev.load_state_dict(
                    torch.load('../swa/cycle_{}.pth'.format(i),
                               map_location=lambda storage, loc: storage))
                # average weights
                for param_c, param_p in zip(net.parameters(),
                                            prev.parameters()):
                    param_c.data *= (1.0 - alpha)
                    param_c.data += param_p.data.to(device) * alpha

        # recompute batch-norm running statistics for the averaged weights
        bn_update(train_loader, net, args.gpu)

    net.eval()
    torch.save(
        net.state_dict(), '../model_weights/swa_{}_{}_fold-{}.pth'.format(
            args.model_name, args.exp_name, fold))

    import pandas as pd

    out_dict = {
        'train_losses': train_losses,
        'valid_losses': valid_losses,
        'valid_ious': valid_ious
    }

    out_log = pd.DataFrame(out_dict)
    out_log.to_csv('../logs/resunet_fold-{}.csv'.format(fold), index=False)

    return best_val_iou