Пример #1
0
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    """Train ``model`` for one epoch and write a per-epoch checkpoint.

    On epoch 0 a warm-up learning-rate scheduler from
    ``utils.warmup_lr_scheduler`` (factor 1/1000, at most 1000 iterations)
    is stepped after every optimizer step. Per-batch losses are reduced
    with ``utils.reduce_dict`` for logging; if the reduced loss is not
    finite the process exits with status 1.

    Args:
        model: model whose forward ``model(images, targets)`` returns a dict
            of loss tensors that are summed into the training loss.
        optimizer: optimizer over the model's parameters.
        data_loader: iterable of ``(images, targets)`` batches; ``images`` is
            a sequence of tensors and ``targets`` a sequence of dicts of
            tensors, all moved to ``device``.
        device: device the batches are moved to.
        epoch: epoch index, used for the log header, the warm-up decision
            and the checkpoint file name.
        print_freq: logging interval forwarded to ``metric_logger.log_every``.

    Side effects:
        Creates ``./checkpoint`` if needed, logs the summed (un-reduced)
        epoch loss through the module-level writer ``tb`` (presumably a
        TensorBoard ``SummaryWriter`` -- TODO confirm), and saves
        ``./checkpoint/chkpoint_colab_{epoch}.pt`` through ``utils.save_ckp``.
    """
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs('./checkpoint', exist_ok=True)
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    checkpoint_path = './checkpoint/chkpoint_colab_{}.pt'.format(epoch)

    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)

    total_loss = 0.0
    losses = None  # last batch's loss; stays None if data_loader is empty

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # Accumulate a plain float. Adding the loss tensor itself would chain
        # every iteration's autograd graph onto ``total_loss`` and keep them
        # all alive until the end of the epoch.
        total_loss += losses.item()

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    tb.add_scalar('Train Loss', total_loss, epoch)
    checkpoint = {
        'epoch': epoch + 1,
        # Last batch's loss, detached so the checkpoint does not pin the graph.
        'train_loss_min': losses.detach() if losses is not None else None,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    utils.save_ckp(checkpoint, False, checkpoint_path, None)
Пример #2
0
def train_model(model, dataloaders, criterion, optimizer, scheduler, dataset_sizes, checkpoint_path, num_epochs=25):
    """Train and validate ``model`` for ``num_epochs``, keeping the best weights.

    Each epoch runs a ``'train'`` phase followed by a ``'val'`` phase. The
    ``'train'`` phase does backward/step through the module-level ``scaler``
    (presumably a ``torch.cuda.amp.GradScaler`` -- TODO confirm), then calls
    ``scheduler.step()`` once. The ``'val'`` phase runs with gradients
    disabled. Whenever validation accuracy improves, the weights are
    remembered and a checkpoint is written to ``checkpoint_path`` via
    ``save_ckp``.

    Args:
        model: classifier whose predictions are ``torch.max(outputs, 1)``.
        dataloaders: mapping with ``'train'`` and ``'val'`` loaders yielding
            ``(inputs, labels)``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped through ``scaler``.
        scheduler: LR scheduler stepped once per training phase.
        dataset_sizes: mapping of phase name to sample count, used to average
            loss and accuracy.
        checkpoint_path: destination of the best-model checkpoint.
        num_epochs: number of epochs to run.

    Returns:
        ``(model, best_acc)``: the model with the best validation weights
        loaded, and the best validation accuracy (a 0-d tensor once any
        epoch improved on the initial ``0``).
    """
    print(f"saving to {checkpoint_path}")
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0
    loss_p = {'train': [], 'val': []}
    acc_p = {'train': [], 'val': []}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            tk = tqdm(dataloaders[phase], total=len(dataloaders[phase]))
            for inputs, labels in tk:
                inputs = inputs.to(config.DEVICE)
                labels = labels.to(config.DEVICE)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            # Record plain floats. ``epoch_acc`` is a 0-d tensor on
            # config.DEVICE; lists of CUDA tensors cannot be converted to
            # NumPy, which plotting (presumably matplotlib in ``plot``) needs.
            # ``epoch_loss`` is already a float.
            loss_p[phase].append(epoch_loss)
            acc_p[phase].append(epoch_acc.item())

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                checkpoint = {
                    'epoch': epoch,
                    'valid_acc': best_acc,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                print(f"saving to {checkpoint_path}")
                save_ckp(checkpoint, checkpoint_path)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)

    plot(loss_p, acc_p, num_epochs)

    return model, best_acc
Пример #3
0
def main(ckp_path=None):
    """Train the classifier and save a final checkpoint.

    Trains from scratch when ``ckp_path`` is None, with the model built by
    ``initialize_model`` using ``feature_extract=True`` and
    ``use_pretrained=True``. Otherwise it re-trains the model returned by
    ``load_ckp(ckp_path)``, starting from that model's stored ``best_score``.

    Args:
        ckp_path (str | None): checkpoint path to resume from, or None to
            train from scratch.

    Returns:
        The model with the best-validation-score weights loaded and the
        ``class_to_idx``, ``best_score`` and ``model_name`` attributes set.
    """
    cli_args = get_train_args(__author__, __version__)

    # Variables
    data_dir = cli_args.data_dir
    save_dir = cli_args.save_dir
    file_name = cli_args.file_name
    use_gpu = cli_args.use_gpu

    # LOAD DATA
    data_loaders = load_data(data_dir, config.IMG_SIZE, config.BATCH_SIZE)

    # BUILD MODEL
    if ckp_path is None:
        model = initialize_model(model_name=config.MODEL_NAME,
                                 num_classes=config.NO_OF_CLASSES,
                                 feature_extract=True,
                                 use_pretrained=True)
    else:
        model = load_ckp(ckp_path)

    # Device is available or not
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # If the user wants the gpu mode, check if cuda is available.
    # NOTE(review): ``use_gpu`` only drives this message; ``device`` above is
    # chosen from CUDA availability alone. Confirm whether the flag should
    # also be able to force CPU.
    if use_gpu == True and not torch.cuda.is_available():
        print("GPU mode is not available, using CPU...")
        use_gpu = False

    # MOVE MODEL TO AVAILABLE DEVICE
    model.to(device)

    # DEFINE OPTIMIZER
    optimizer = optimizer_fn(model_name=config.MODEL_NAME,
                             model=model,
                             lr_rate=config.LR_RATE)

    # DEFINE SCHEDULER (driven by the validation loss, see scheduler.step below)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode="min",
                                                           patience=5,
                                                           factor=0.3,
                                                           verbose=True)

    # DEFINE LOSS FUNCTION
    criterion = loss_fn()

    # Snapshot of the current weights; replaced whenever validation improves.
    best_model_wts = copy.deepcopy(model.state_dict())

    # BEST VALIDATION SCORE
    if ckp_path is None:
        best_score = -1  # IF MODEL IS TRAINED FROM SCRATCH
    else:
        best_score = model.best_score  # IF MODEL IS RE-TRAINED

    # NO OF ITERATION
    no_epochs = config.EPOCHS
    # KEEP TRACK OF LOSS AND ACCURACY IN EACH EPOCH
    stats = {
        'train_losses': [],
        'valid_losses': [],
        'train_accuracies': [],
        'valid_accuracies': []
    }

    print("Models's Training Start......")

    for epoch in range(1, no_epochs + 1):
        train_loss, train_score = train_fn(data_loaders,
                                           model,
                                           optimizer,
                                           criterion,
                                           device,
                                           phase='train')
        # Evaluate on the same device the model was moved to; config.DEVICE
        # need not match ``device`` and a mismatch breaks the forward pass.
        val_loss, val_score = eval_fn(data_loaders,
                                      model,
                                      criterion,
                                      device=device,
                                      phase='valid')
        scheduler.step(val_loss)

        # Remember the weights if validation score improved (kept in memory;
        # the checkpoint file is written once, after training).
        if val_score > best_score:
            print(
                'Validation score increased ({:.6f} --> {:.6f}).  Saving model ...'
                .format(best_score, val_score))
            best_score = val_score
            best_model_wts = copy.deepcopy(model.state_dict())

        # MAKE A RECORD OF AVERAGE LOSSES AND ACCURACY IN EACH EPOCH FOR PLOTTING
        stats['train_losses'].append(train_loss)
        stats['valid_losses'].append(val_loss)
        stats['train_accuracies'].append(train_score)
        stats['valid_accuracies'].append(val_score)

        # PRINT TRAINING AND VALIDATION LOSS/ACCURACIES AFTER EACH EPOCH
        epoch_len = len(str(no_epochs))
        print_msg = (f'[{epoch:>{epoch_len}}/{no_epochs:>{epoch_len}}] ' +
                     '\t' + f'train_loss: {train_loss:.5f} ' + '\t' +
                     f'train_score: {train_score:.5f} ' + '\t' +
                     f'valid_loss: {val_loss:.5f} ' + '\t' +
                     f'valid_score: {val_score:.5f}')
        print(print_msg)

    # load best model weights
    model.load_state_dict(best_model_wts)

    # create checkpoint variable and add important data
    model.class_to_idx = data_loaders['train'].dataset.class_to_idx
    model.best_score = best_score
    model.model_name = config.MODEL_NAME
    checkpoint = {
        'epoch': no_epochs,
        'lr_rate': config.LR_RATE,
        'model_name': config.MODEL_NAME,
        'batch_size': config.BATCH_SIZE,
        'valid_score': best_score,
        'optimizer': optimizer.state_dict(),
        'state_dict': model.state_dict(),
        'class_to_idx': model.class_to_idx
    }

    # SAVE CHECKPOINT
    save_ckp(checkpoint, save_dir, file_name)

    print("Models's Training is Successfull......")

    return model
Пример #4
0
def train(args):
    """Train an LSTM-based segmentation network configured by ``args``.

    Reads hyper-parameters from ``args`` and ``config.yaml`` and builds three
    components:

    * the network selected by ``args.model_type`` / ``args.LSTM_variant``;
    * the loss selected by ``args.loss_fn``;
    * the optimizer named in the config.

    It then trains for ``args.epoch`` epochs, running validation once every
    ``ceil(n_train / batch_size)`` training steps. After each epoch it does
    the following:

    * saves a checkpoint through ``save_ckp``, whose return value triggers
      early stopping;
    * writes the training/validation statistics as JSON;
    * logs parameter and gradient histograms to the ``SummaryWriter``.

    Args:
        args: namespace with attributes ``model_name``, ``model_type``,
            ``lstmbase``, ``unetbase``, ``layer_num``, ``nb_shortcut``,
            ``loss_fn``, ``world_size``, ``rank``, ``base_channels``,
            ``crop_size``, ``ignore_idx``, ``return_sequence``,
            ``LSTM_variant``, ``epoch`` and ``is_pretrain``.

    Raises:
        NotImplementedError: for an unknown model type, LSTM variant, loss
            function or optimizer name.
    """

    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined parameters
    model_name = args.model_name
    model_type = args.model_type
    lstm_backbone = args.lstmbase
    unet_backbone = args.unetbase
    layer_num = args.layer_num
    nb_shortcut = args.nb_shortcut
    loss_fn = args.loss_fn
    world_size = args.world_size
    rank = args.rank
    base_channel = args.base_channels
    crop_size = args.crop_size
    ignore_idx = args.ignore_idx
    return_sequence = args.return_sequence
    variant = args.LSTM_variant
    epochs = args.epoch
    is_pretrain = args.is_pretrain

    # system setup parameters
    config_file = 'config.yaml'
    config = load_config(config_file)
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['model_root']
    model_dir = config['PATH']['save_ckp']
    best_dir = config['PATH']['save_best_model']

    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    time_step = int(config['PARAMETERS']['time_step'])
    num_workers = int(config['PARAMETERS']['num_workers'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])
    # Learning rates are fractional: int() truncated e.g. 0.01 to 0, and
    # raised ValueError for strings such as '1e-3'.
    lr = float(config['PARAMETERS']['lr'])
    optimizer = config['PARAMETERS']['optimizer']
    connect = config['PARAMETERS']['connect']
    conv_type = config['PARAMETERS']['lstm_convtype']

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_newdata',
                                          model_name)
    train_info_file = os.path.join(intermidiate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    if not os.path.exists(best_path):
        os.mkdir(best_path)
    if not os.path.exists(intermidiate_data_save):
        os.makedirs(intermidiate_data_save)
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('labels {} are ignored'.format(ignore_idx))
    logger.info('Dataset is loading ...')
    writer = SummaryWriter('ProcessVisu/%s' % model_name)

    # load training set and validation set
    data_class = data_split()
    train, val, test = data_construction(data_class)
    train_dict = time_parser(train, time_patch=time_step)
    val_dict = time_parser(val, time_patch=time_step)

    # LSTM initialization

    if model_type == 'LSTM':
        net = LSTMSegNet(lstm_backbone=lstm_backbone,
                         input_dim=input_modalites,
                         output_dim=output_channels,
                         hidden_dim=base_channel,
                         kernel_size=3,
                         num_layers=layer_num,
                         conv_type=conv_type,
                         return_sequence=return_sequence)
    elif model_type == 'UNet_LSTM':
        if variant == 'back':
            net = BackLSTM(input_dim=input_modalites,
                           hidden_dim=base_channel,
                           output_dim=output_channels,
                           kernel_size=3,
                           num_layers=layer_num,
                           conv_type=conv_type,
                           lstm_backbone=lstm_backbone,
                           unet_module=unet_backbone,
                           base_channel=base_channel,
                           return_sequence=return_sequence,
                           is_pretrain=is_pretrain)
            logger.info(
                'the pretrained status of backbone is {}'.format(is_pretrain))
        elif variant == 'center':
            net = CenterLSTM(input_modalites=input_modalites,
                             output_channels=output_channels,
                             base_channel=base_channel,
                             num_layers=layer_num,
                             conv_type=conv_type,
                             return_sequence=return_sequence,
                             is_pretrain=is_pretrain)
        elif variant == 'bicenter':
            net = BiCenterLSTM(input_modalites=input_modalites,
                               output_channels=output_channels,
                               base_channel=base_channel,
                               num_layers=layer_num,
                               connect=connect,
                               conv_type=conv_type,
                               return_sequence=return_sequence,
                               is_pretrain=is_pretrain)
        elif variant == 'directcenter':
            net = DirectCenterLSTM(input_modalites=input_modalites,
                                   output_channels=output_channels,
                                   base_channel=base_channel,
                                   num_layers=layer_num,
                                   conv_type=conv_type,
                                   return_sequence=return_sequence,
                                   is_pretrain=is_pretrain)
        elif variant == 'bidirectcenter':
            net = BiDirectCenterLSTM(input_modalites=input_modalites,
                                     output_channels=output_channels,
                                     base_channel=base_channel,
                                     num_layers=layer_num,
                                     connect=connect,
                                     conv_type=conv_type,
                                     return_sequence=return_sequence,
                                     is_pretrain=is_pretrain)
        elif variant == 'rescenter':
            net = ResCenterLSTM(input_modalites=input_modalites,
                                output_channels=output_channels,
                                base_channel=base_channel,
                                num_layers=layer_num,
                                conv_type=conv_type,
                                return_sequence=return_sequence,
                                is_pretrain=is_pretrain)
        elif variant == 'birescenter':
            net = BiResCenterLSTM(input_modalites=input_modalites,
                                  output_channels=output_channels,
                                  base_channel=base_channel,
                                  num_layers=layer_num,
                                  connect=connect,
                                  conv_type=conv_type,
                                  return_sequence=return_sequence,
                                  is_pretrain=is_pretrain)
        elif variant == 'shortcut':
            net = ShortcutLSTM(input_modalites=input_modalites,
                               output_channels=output_channels,
                               base_channel=base_channel,
                               num_layers=layer_num,
                               num_connects=nb_shortcut,
                               conv_type=conv_type,
                               return_sequence=return_sequence,
                               is_pretrain=is_pretrain)
        else:
            # Previously fell through and failed later with NameError on ``net``.
            raise NotImplementedError(
                'unknown LSTM variant {!r}'.format(variant))
    else:
        raise NotImplementedError()

    # loss and optimizer setup
    if loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_idx=ignore_idx)
    elif loss_fn == 'GDice':
        criterion = GneralizedDiceLoss(labels=labels)
    elif loss_fn == 'WCE':
        criterion = WeightedCrossEntropyLoss(labels=labels)
    else:
        raise NotImplementedError()

    if optimizer == 'adam':
        # NOTE(review): Adam uses a fixed lr of 0.001, not the configured ``lr``.
        optimizer = optim.Adam(net.parameters(), lr=0.001)
    elif optimizer == 'sgd':
        optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=lr)
    else:
        # Previously ``optimizer`` stayed a string and the scheduler below failed.
        raise NotImplementedError(
            'unknown optimizer {!r}'.format(optimizer))
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    # device setup
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # net, optimizer = amp.initialize(net, optimizer, opt_level="O1")

    if torch.cuda.device_count() > 1:
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:38366',
            rank=rank,
            world_size=world_size)
    if distributed_is_initialized():
        print('distributed is initialized')
        net.to(device)
        net = nn.parallel.DistributedDataParallel(net,
                                                  find_unused_parameters=True)
    else:
        print('data parallel')
        net = nn.DataParallel(net)
        net.to(device)

    min_loss = float('Inf')
    early_stop_count = 0
    global_step = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'label_0_acc': [],
        'label_1_acc': [],
        'label_2_acc': [],
        'label_3_acc': [],
        'label_4_acc': []
    }

    if is_resume:
        try:
            # open previous check points
            ckp_path = os.path.join(model_path,
                                    '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)

            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Mininum training loss from last time is {}'.format(min_loss))
            logger.info(
                'Training accuracies from last time are: label 0: {}, label 1: {}, label 2: {}, label 3: {}, label 4: {}'
                .format(train_info['label_0_acc'][-1],
                        train_info['label_1_acc'][-1],
                        train_info['label_2_acc'][-1],
                        train_info['label_3_acc'][-1],
                        train_info['label_4_acc'][-1]))

        except Exception as err:
            # Best-effort resume: fall back to training from scratch, but keep
            # the reason (a bare ``except:`` also hid it and caught
            # KeyboardInterrupt/SystemExit).
            logger.warning(
                'No checkpoint available, strat training from scratch')
            logger.warning('Resume failed: {!r}'.format(err))

    # Validation results carried across epochs; initialised so the per-epoch
    # bookkeeping does not raise NameError if validation has not run yet.
    val_loss, val_acc, val_info = None, None, {}

    for epoch in range(start_epoch, epochs):

        train_set = data_loader(train_dict,
                                batch_size=batch_size,
                                key='train',
                                num_works=num_workers,
                                time_step=time_step,
                                patch=crop_size,
                                model_type='RNN')
        n_train = len(train_set)

        val_set = data_loader(val_dict,
                              batch_size=batch_size,
                              key='val',
                              num_works=num_workers,
                              time_step=time_step,
                              patch=crop_size,
                              model_type='CNN')
        n_val = len(val_set)

        logger.info('Dataset loading finished!')

        nb_batches = np.ceil(n_train / batch_size)
        n_total = n_train + n_val
        logger.info(
            '{} images will be used in total, {} for trainning and {} for validation'
            .format(n_total, n_train, n_val))

        train_loader = train_set.load()

        # setup to train mode
        net.train()
        running_loss = 0
        dice_score_label_0 = 0
        dice_score_label_1 = 0
        dice_score_label_2 = 0
        dice_score_label_3 = 0
        dice_score_label_4 = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:

            for i, data in enumerate(train_loader, 0):

                # i : patient
                images, segs = data['image'].to(device), data['seg'].to(device)

                # Clear the previous step's gradients; without this they
                # accumulate across every iteration of training.
                optimizer.zero_grad()
                outputs = net(images)
                loss = criterion(outputs, segs)
                loss.backward()
                optimizer.step()

                # if i == 0:
                #     in_images = images.detach().cpu().numpy()[0]
                #     in_segs = segs.detach().cpu().numpy()[0]
                #     in_pred = outputs.detach().cpu().numpy()[0]
                #     heatmap_plot(image=in_images, mask=in_segs, pred=in_pred, name=model_name, epoch=epoch+1, is_train=True)

                running_loss += loss.detach().item()

                # Fold leading (batch/time) dims together so each slice is
                # scored by ``dice`` on its own.
                outputs = outputs.view(-1, outputs.shape[-4],
                                       outputs.shape[-3], outputs.shape[-2],
                                       outputs.shape[-1])
                segs = segs.view(-1, segs.shape[-3], segs.shape[-2],
                                 segs.shape[-1])
                _, preds = torch.max(outputs.data, 1)
                dice_score = dice(preds.data.cpu(),
                                  segs.data.cpu(),
                                  ignore_idx=None)

                dice_score_label_0 += dice_score['bg']
                dice_score_label_1 += dice_score['csf']
                dice_score_label_2 += dice_score['gm']
                dice_score_label_3 += dice_score['wm']
                dice_score_label_4 += dice_score['tm']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'training loss': loss.detach().item(),
                        'Training accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

                global_step += 1
                if global_step % nb_batches == 0:
                    net.eval()
                    val_loss, val_acc, val_info = validation(net,
                                                             val_set,
                                                             criterion,
                                                             device,
                                                             batch_size,
                                                             ignore_idx=None,
                                                             name=model_name,
                                                             epoch=epoch + 1)
                    net.train()

        train_info['train_loss'].append(running_loss / nb_batches)
        train_info['val_loss'].append(val_loss)
        train_info['label_0_acc'].append(dice_score_label_0 / nb_batches)
        train_info['label_1_acc'].append(dice_score_label_1 / nb_batches)
        train_info['label_2_acc'].append(dice_score_label_2 / nb_batches)
        train_info['label_3_acc'].append(dice_score_label_3 / nb_batches)
        train_info['label_4_acc'].append(dice_score_label_4 / nb_batches)

        # save best trained model (selection is by average training loss)
        scheduler.step(running_loss / nb_batches)
        logger.info('Epoch: {}, LR: {}'.format(
            epoch + 1, optimizer.param_groups[0]['lr']))

        if min_loss > running_loss / nb_batches:
            min_loss = running_loss / nb_batches
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': running_loss / nb_batches,
            'min_loss': min_loss
        }
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        # summarize the training results of this epoch
        logger.info('The average training loss for this epoch is {}'.format(
            running_loss / nb_batches))
        logger.info('The best training loss till now is {}'.format(min_loss))
        logger.info(
            'Validation dice loss: {}; Validation (avg) accuracy of the last timestep: {}'
            .format(val_loss, val_acc))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        val_info_file = os.path.join(intermidiate_data_save,
                                     '{}_val_info.json'.format(model_name))
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        with open(val_info_file, 'w') as fp:
            json.dump(val_info, fp)

        for name, layer in net.named_parameters():
            if layer.requires_grad:
                # Parameters that took no part in the loss (the DDP wrapper is
                # built with find_unused_parameters=True) have ``grad`` None.
                if layer.grad is not None:
                    writer.add_histogram(name + '_grad',
                                         layer.grad.cpu().data.numpy(), epoch)
                writer.add_histogram(name + '_data',
                                     layer.cpu().data.numpy(), epoch)
        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    loss_plot(train_info_file, name=model_name)
    logger.info('finish training!')

    return
def train(pretrained=True):
    """Train the Faster R-CNN detector on the ``Wheatset`` data for 5 epochs.

    Builds train/validation ``utils.Wheatset`` datasets from ``train_csv``,
    optionally initialises the model from the best-model checkpoint, and
    trains with SGD (lr 0.005, momentum 0.9, weight decay 0.0005). After every
    epoch a checkpoint is written through ``utils.save_ckp``. It is also saved
    as the best model whenever the mean training loss is at or below the
    lowest value so far, which starts at the threshold 0.9.

    Args:
        pretrained: when True, load ``state_dict`` from
            ``/checkpoints/bestmodel_may28.pt`` before training.
    """
    train_df, val_df = utils.process_csv(train_csv)

    train_set = utils.Wheatset(train_df, train_dir, phase='train')
    val_set = utils.Wheatset(val_df, train_dir, phase='validation')

    # Transpose a list of samples into per-field tuples
    # (images, targets, image_ids) instead of stacking them into tensors.
    def collate_fn(batch):
        return tuple(zip(*batch))

    # Shuffle the training set so every epoch sees a different batch order;
    # an unshuffled training loader feeds SGD the same sequence each epoch.
    train_data_loader = DataLoader(train_set,
                                   batch_size=8,
                                   shuffle=True,
                                   num_workers=2,
                                   collate_fn=collate_fn)

    # NOTE(review): built but not used yet -- best-model selection below is
    # still driven by the training loss (see the TODO there).
    valid_data_loader = DataLoader(val_set,
                                   batch_size=8,
                                   shuffle=False,
                                   num_workers=2,
                                   collate_fn=collate_fn)

    checkpoint_path = '/checkpoints/chkpoint_'
    best_model_path = '/checkpoints/bestmodel_may28.pt'

    # construct fasterrcnn network
    model = models.construct_models()
    if pretrained:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        weights = torch.load(best_model_path, map_location=device)
        model.load_state_dict(weights['state_dict'])

    model.to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params,
                                lr=0.005,
                                momentum=0.9,
                                weight_decay=0.0005)

    # train
    num_epochs = 5
    # Only epochs whose mean training loss is <= this value are saved as best.
    train_loss_min = 0.9
    total_train_loss = []

    for epoch in range(num_epochs):
        print(f'Epoch :{epoch + 1}')
        start_time = time.time()
        train_loss = []
        model.train()
        for images, targets, _ in train_data_loader:
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device)
                        for k, v in t.items()} for t in targets]

            loss_dict = model(images, targets)

            losses = sum(loss for loss in loss_dict.values())
            train_loss.append(losses.item())
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

        epoch_train_loss = np.mean(train_loss)
        total_train_loss.append(epoch_train_loss)
        print(f'Epoch train loss is {epoch_train_loss}')

        # create checkpoint variable and add important data
        checkpoint = {
            'epoch': epoch + 1,
            'train_loss_min': epoch_train_loss,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }

        # save checkpoint
        utils.save_ckp(checkpoint, False, checkpoint_path, best_model_path)
        # TODO: save the model if validation loss has decreased
        if epoch_train_loss <= train_loss_min:
            print(
                'Train loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.
                format(train_loss_min, epoch_train_loss))
            # save checkpoint as best model
            utils.save_ckp(checkpoint, True, checkpoint_path, best_model_path)
            train_loss_min = epoch_train_loss

        time_elapsed = time.time() - start_time
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
Пример #6
0
def train(args):
    """Train a U-Net segmentation model configured by ``config.yaml``.

    Two values come from ``args``:

    * ``args.model_name`` -- prefix for the log, checkpoint and
      training-info files.
    * ``args.loss_fn`` -- one of ``'Dice'``, ``'CrossEntropy'``,
      ``'FocalLoss'``, ``'Dice_CE'`` or ``'Dice_FL'``.

    Every other setting (paths, channel counts, crop/batch size, epochs,
    patience values, labels) is read from ``config.yaml``.

    Each epoch does the following:

    * trains on the ``'train'`` split with apex ``amp`` mixed precision
      (opt level ``O1``);
    * runs one validation pass on the ``'val'`` split;
    * steps the ReduceLROnPlateau scheduler with the mean training loss;
    * appends the per-epoch metrics to
      ``<root>/train_data/<model_name>/<model_name>_train_info.json``;
    * saves a checkpoint through ``save_ckp``.

    Training stops early when ``save_ckp`` returns a truthy value. This
    function logs that as "validation loss has not improved for
    ``early_stop_patience`` epochs".

    Raises:
        NotImplementedError: if ``args.loss_fn`` is not one of the
            supported loss names.
    """
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_loss_fn = args.loss_fn

    config_file = 'config.yaml'

    config = load_config(config_file)
    data_root = config['PATH']['data_root']
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['root']
    model_dir = config['PATH']['model_path']
    best_dir = config['PATH']['best_model_path']

    data_class = config['PATH']['data_class']
    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    base_channel = int(config['PARAMETERS']['base_channels'])
    crop_size = int(config['PARAMETERS']['crop_size'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    epochs = int(config['PARAMETERS']['epoch'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    ignore_idx = int(config['PARAMETERS']['ignore_index'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])

    # build up dirs (makedirs also creates any missing parent directory)
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_data', model_name)
    train_info_file = os.path.join(intermidiate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    for directory in (model_path, best_path, intermidiate_data_save, log_path):
        os.makedirs(directory, exist_ok=True)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')
    # split dataset
    dir_ = os.path.join(data_root, data_class)
    data_content = train_split(dir_)

    # load training set and validation set
    train_set = data_loader(data_content=data_content,
                            key='train',
                            form='LGG',
                            crop_size=crop_size,
                            batch_size=batch_size,
                            num_works=8)
    n_train = len(train_set)
    train_loader = train_set.load()

    val_set = data_loader(data_content=data_content,
                          key='val',
                          form='LGG',
                          crop_size=crop_size,
                          batch_size=batch_size,
                          num_works=8)

    logger.info('Dataset loading finished!')

    n_val = len(val_set)
    n_total = n_train + n_val
    logger.info(
        '{} images will be used in total, {} for trainning and {} for validation'
        .format(n_total, n_train, n_val))

    net = init_U_Net(input_modalites, output_channels, base_channel)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net.to(device)

    if model_loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'CrossEntropy':
        criterion = CrossEntropyLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'FocalLoss':
        criterion = FocalLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_CE':
        criterion = Dice_CE(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_FL':
        criterion = Dice_FL(labels=labels, ignore_index=ignore_idx)
    else:
        raise NotImplementedError()

    optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    # apex's amp.initialize must receive the bare model: it raises a
    # RuntimeError for a model already wrapped in nn.DataParallel. The
    # multi-GPU wrapper is therefore applied only after initialization.
    net, optimizer = amp.initialize(net, optimizer, opt_level='O1')

    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs available.'.format(torch.cuda.device_count()))
        net = nn.DataParallel(net)

    min_loss = float('Inf')
    early_stop_count = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'BG_acc': [],
        'NET_acc': [],
        'ED_acc': [],
        'ET_acc': []
    }

    if is_resume:
        try:
            # save_ckp below is given save_model_dir=model_path, so resume
            # from that same directory (not the bare, cwd-relative model_dir).
            ckp_path = os.path.join(model_path,
                                    '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)

            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Mininum training loss from last time is {}'.format(min_loss))

        except Exception as err:
            # Best effort: a missing/unreadable checkpoint falls back to a
            # fresh run, but the reason is logged instead of being hidden.
            logger.warning(
                'No checkpoint available, strat training from scratch')
            logger.warning('Resume failed: {}'.format(err))

    # start training
    for epoch in range(start_epoch, epochs):

        # setup to train mode
        net.train()
        running_loss = 0
        dice_coeff_bg = 0
        dice_coeff_net = 0
        dice_coeff_ed = 0
        dice_coeff_et = 0
        n_batches = 0  # batches actually processed this epoch

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:

            for i, data in enumerate(train_loader, 0):
                images, segs = data['image'].to(device), data['seg'].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = net(images)

                loss = criterion(outputs, segs)
                # scale the loss for mixed-precision backward (apex amp)
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()

                optimizer.step()

                # save the output at the begining of each epoch to visulize it
                if i == 0:
                    in_images = images.detach().cpu().numpy()[:, 0, ...]
                    in_segs = segs.detach().cpu().numpy()
                    in_pred = outputs.detach().cpu().numpy()
                    heatmap_plot(image=in_images,
                                 mask=in_segs,
                                 pred=in_pred,
                                 name=model_name,
                                 epoch=epoch + 1)

                running_loss += loss.detach().item()
                dice_score = dice_coe(outputs.detach().cpu(),
                                      segs.detach().cpu())
                dice_coeff_bg += dice_score['BG']
                dice_coeff_ed += dice_score['ED']
                dice_coeff_et += dice_score['ET']
                dice_coeff_net += dice_score['NET']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'Training loss': loss.detach().item(),
                        'Training (avg) accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])
                n_batches += 1

        # Validate exactly once per epoch, after every training batch has
        # run. A step-count trigger can miss an epoch (leaving val_loss
        # undefined or stale) and can leave the net in eval mode mid-epoch.
        net.eval()
        val_loss, val_acc = validation(net, val_set, criterion, device,
                                       batch_size)

        # Average over the batches actually seen; max() guards an empty loader.
        n_batches = max(n_batches, 1)
        avg_train_loss = running_loss / n_batches

        train_info['train_loss'].append(avg_train_loss)
        train_info['val_loss'].append(val_loss)
        train_info['BG_acc'].append(dice_coeff_bg / n_batches)
        train_info['NET_acc'].append(dice_coeff_net / n_batches)
        train_info['ED_acc'].append(dice_coeff_ed / n_batches)
        train_info['ET_acc'].append(dice_coeff_et / n_batches)

        # the LR scheduler follows the mean training loss
        scheduler.step(avg_train_loss)

        # best-model / early-stop bookkeeping follows the validation loss
        if min_loss > val_loss:
            min_loss = val_loss
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': avg_train_loss,
            'min_loss': min_loss
        }
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        logger.info('The average training loss for this epoch is {}'.format(
            avg_train_loss))
        logger.info(
            'Validation dice loss: {}; Validation (avg) accuracy: {}'.format(
                val_loss, val_acc))
        logger.info('The best validation loss till now is {}'.format(min_loss))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)

        loss_plot(train_info_file, name=model_name)

        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    logger.info('finish training!')
Пример #7
0
def train(args):
    """Train a 3-D segmentation CNN on time-stacked patches.

    Values taken from ``args``:

    * ``model_name``
    * ``model_type`` -- one of ``'UNet'``, ``'ResUNet'``, ``'DResUNet'``,
      ``'direct_concat'``, ``'Inception'``, ``'Simple_Inception'``.
    * ``loss`` -- one of ``'Dice'``, ``'GDice'``, ``'CrossEntropy'``.
    * ``world_size`` and ``rank`` (for ``torch.distributed``).
    * ``base_channels``, ``crop_size``, ``ignore_idx``, ``epoch``.

    Everything else (paths, batch size, patience values, time step,
    optimizer name, learning rate, ...) is read from ``config.yaml``.

    Each epoch does the following:

    * re-samples train/val patches with ``data_loader``;
    * folds the time axis into the batch axis and trains;
    * runs one validation pass and steps ReduceLROnPlateau with the
      validation loss;
    * tracks the best *training* loss (improvement must exceed 1e-2) for
      checkpoint selection and early stopping;
    * writes train/val info JSON files and TensorBoard histograms.

    Raises:
        NotImplementedError: if ``model_type``, ``loss`` or the configured
            optimizer is not supported.
    """
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_type = args.model_type
    loss_func = args.loss
    world_size = args.world_size
    rank = args.rank
    base_channel = args.base_channels
    crop_size = args.crop_size
    ignore_idx = args.ignore_idx
    epochs = args.epoch

    # system setup
    config_file = 'config.yaml'
    config = load_config(config_file)
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['model_root']
    model_dir = config['PATH']['save_ckp']
    best_dir = config['PATH']['save_best_model']

    output_channels = int(config['PARAMETERS']['output_channels'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    time_step = int(config['PARAMETERS']['time_step'])
    num_workers = int(config['PARAMETERS']['num_workers'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])
    pad_method = config['PARAMETERS']['pad_method']
    # Learning rates are fractional (e.g. 0.01); int() would truncate them to
    # 0 (silently freezing SGD) or fail on strings such as '1e-3'.
    lr = float(config['PARAMETERS']['lr'])
    optimizer_name = config['PARAMETERS']['optimizer']
    softmax = True
    modalities = ['flair', 't1', 't1gd', 't2']
    input_modalites = len(modalities)

    # build up dirs (makedirs also creates any missing parent directory)
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_newdata',
                                          model_name)
    train_info_file = os.path.join(intermidiate_data_save,
                                   '{}_train_info.json'.format(model_name))
    val_info_file = os.path.join(intermidiate_data_save,
                                 '{}_val_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    for directory in (model_path, best_path, intermidiate_data_save, log_path):
        os.makedirs(directory, exist_ok=True)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')
    writer = SummaryWriter('ProcessVisu/%s' % model_name)

    logger.info('patch size: {}'.format(crop_size))

    # load training set and validation set
    data_class = data_split()
    train, val, test = data_construction(data_class)
    train_dict = time_parser(train, time_patch=time_step)
    val_dict = time_parser(val, time_patch=time_step)

    if model_type == 'UNet':
        net = init_U_Net(input_modalites, output_channels, base_channel,
                         pad_method, softmax)
    elif model_type == 'ResUNet':
        net = ResUNet(input_modalites, output_channels, base_channel,
                      pad_method, softmax)
    elif model_type == 'DResUNet':
        net = DResUNet(input_modalites, output_channels, base_channel,
                       pad_method, softmax)
    elif model_type == 'direct_concat':
        net = U_Net_direct_concat(input_modalites, output_channels,
                                  base_channel, pad_method, softmax)
    elif model_type == 'Inception':
        net = Inception_UNet(input_modalites, output_channels, base_channel,
                             softmax)
    elif model_type == 'Simple_Inception':
        net = Simplified_Inception_UNet(input_modalites, output_channels,
                                        base_channel, softmax)
    else:
        # fail clearly instead of a NameError on `net` further down
        raise NotImplementedError(
            'Unsupported model type: {}'.format(model_type))

    # device setup
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net.to(device)

    # print model structure
    summary(net, input_size=(input_modalites, crop_size, crop_size, crop_size))
    dummy_input = torch.rand(1, input_modalites, crop_size, crop_size,
                             crop_size).to(device)
    writer.add_graph(net, (dummy_input, ))

    # loss and optimizer setup
    if loss_func == 'Dice' and softmax:
        criterion = DiceLoss(labels=labels, ignore_idx=ignore_idx)
    elif loss_func == 'GDice' and softmax:
        criterion = GneralizedDiceLoss(labels=labels)
    elif loss_func == 'CrossEntropy':
        if softmax:
            criterion = WeightedCrossEntropyLoss(labels=labels)
        else:
            # .to(device) instead of .cuda() so CPU-only runs also work
            criterion = nn.CrossEntropyLoss().to(device)
    else:
        raise NotImplementedError()

    if optimizer_name == 'adam':
        # NOTE: Adam runs with PyTorch's default learning rate; the
        # configured `lr` is only applied to SGD.
        optimizer = optim.Adam(net.parameters())
    elif optimizer_name == 'sgd':
        optimizer = optim.SGD(net.parameters(),
                              momentum=0.9,
                              lr=lr,
                              weight_decay=1e-5)
    else:
        raise NotImplementedError(
            'Unsupported optimizer: {}'.format(optimizer_name))
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs avaliable'.format(torch.cuda.device_count()))
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='tcp://127.0.0.1:38366',
            rank=rank,
            world_size=world_size)
    if distributed_is_initialized():
        logger.info('distributed is initialized')
        net.to(device)
        net = nn.parallel.DistributedDataParallel(net)
    else:
        logger.info('data parallel')
        net = nn.DataParallel(net)
        net.to(device)

    min_loss = float('Inf')
    early_stop_count = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'label_0_acc': [],
        'label_1_acc': [],
        'label_2_acc': [],
        'label_3_acc': [],
        'label_4_acc': []
    }

    if is_resume:
        try:
            # open previous check points
            ckp_path = os.path.join(model_path,
                                    '{}_model_ckp.pth.tar'.format(model_name))
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)
            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Mininum training loss from last time is {}'.format(min_loss))
            logger.info(
                'Training accuracies from last time are: label 0: {}, label 1: {}, label 2: {}, label 3: {}, label 4: {}'
                .format(train_info['label_0_acc'][-1],
                        train_info['label_1_acc'][-1],
                        train_info['label_2_acc'][-1],
                        train_info['label_3_acc'][-1],
                        train_info['label_4_acc'][-1]))

        except Exception as err:
            # Best effort: fall back to a fresh run, but record why.
            logger.warning(
                'No checkpoint available, strat training from scratch')
            logger.warning('Resume failed: {}'.format(err))

    # start training
    for epoch in range(start_epoch, epochs):

        # every epoch generate a new set of images
        train_set = data_loader(train_dict,
                                batch_size=batch_size,
                                key='train',
                                num_works=num_workers,
                                time_step=time_step,
                                patch=crop_size,
                                modalities=modalities,
                                model_type='CNN')
        n_train = len(train_set)
        train_loader = train_set.load()

        val_set = data_loader(val_dict,
                              batch_size=batch_size,
                              key='val',
                              num_works=num_workers,
                              time_step=time_step,
                              patch=crop_size,
                              modalities=modalities,
                              model_type='CNN')
        n_val = len(val_set)

        n_total = n_train + n_val
        logger.info(
            '{} images will be used in total, {} for trainning and {} for validation'
            .format(n_total, n_train, n_val))
        logger.info('Dataset loading finished!')

        # setup to train mode
        net.train()
        running_loss = 0
        dice_score_label_0 = 0
        dice_score_label_1 = 0
        dice_score_label_2 = 0
        dice_score_label_3 = 0
        dice_score_label_4 = 0
        n_batches = 0  # batches actually processed this epoch

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:

            for data in train_loader:
                images, segs = data['image'].to(device), data['seg'].to(device)

                # Fold the time axis into the batch axis, so each time step
                # is fed to the CNN as its own sample.
                _, _, channel, z, y, x = images.shape
                images = images.view(-1, channel, z, y, x)
                segs = segs.view(-1, z, y, x)

                # zero the parameter gradients
                optimizer.zero_grad()
                outputs = net(images)

                loss = criterion(outputs, segs)
                loss.backward()
                optimizer.step()

                running_loss += loss.detach().item()
                _, preds = torch.max(outputs.data, 1)
                dice_score = dice(preds.data.cpu(),
                                  segs.data.cpu(),
                                  ignore_idx=ignore_idx)

                dice_score_label_0 += dice_score['bg']
                dice_score_label_1 += dice_score['csf']
                dice_score_label_2 += dice_score['gm']
                dice_score_label_3 += dice_score['wm']
                dice_score_label_4 += dice_score['tm']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'Training loss': loss.detach().item(),
                        'Training accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

                del images, segs
                n_batches += 1

        # Validate exactly once per epoch. The training set is re-sampled
        # every epoch, so a cross-epoch step-count trigger can skip
        # validation and leave val_loss undefined or stale.
        net.eval()
        val_loss, val_acc, val_info = validation(net,
                                                 val_set,
                                                 criterion,
                                                 device,
                                                 batch_size,
                                                 model_type=model_type,
                                                 softmax=softmax,
                                                 ignore_idx=ignore_idx)

        # Average over the batches actually seen; max() guards an empty loader.
        n_batches = max(n_batches, 1)
        avg_train_loss = running_loss / n_batches

        train_info['train_loss'].append(avg_train_loss)
        train_info['val_loss'].append(val_loss)
        train_info['label_0_acc'].append(dice_score_label_0 / n_batches)
        train_info['label_1_acc'].append(dice_score_label_1 / n_batches)
        train_info['label_2_acc'].append(dice_score_label_2 / n_batches)
        train_info['label_3_acc'].append(dice_score_label_3 / n_batches)
        train_info['label_4_acc'].append(dice_score_label_4 / n_batches)

        # ReduceLROnPlateau.step requires the monitored metric
        scheduler.step(val_loss)
        # debug
        for param_group in optimizer.param_groups:
            logger.info('%0.6f | %6d ' % (param_group['lr'], epoch))

        # best-model / early-stop bookkeeping follows the training loss and
        # requires an improvement larger than 1e-2
        if min_loss > avg_train_loss + 1e-2:
            min_loss = avg_train_loss
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        # save the check point
        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': avg_train_loss,
            'min_loss': min_loss
        }
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        # summarize the training results of this epoch
        logger.info('Average training loss of this epoch is {}'.format(
            avg_train_loss))
        logger.info('Best training loss till now is {}'.format(min_loss))
        logger.info('Validation dice loss: {}; Validation accuracy: {}'.format(
            val_loss, val_acc))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)
        with open(val_info_file, 'w') as fp:
            json.dump(val_info, fp)

        loss_plot(train_info_file, name=model_name)
        for name, param in net.named_parameters():
            # parameters that received no gradient have grad None
            if param.grad is not None:
                writer.add_histogram(name + '_grad',
                                     param.grad.detach().cpu().numpy(), epoch)
            writer.add_histogram(name + '_data',
                                 param.detach().cpu().numpy(), epoch)

        if verbose:
            # early stopping is driven by the training loss (see above)
            logger.info(
                'The training loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    writer.close()
    logger.info('finish training!')
Пример #8
0
      optimizerI.step()
    print("Epoch"+str(epoch),"Step"+str(step),abs(err.item()),abs(t_c_loss.item()))
    if(step%200==0):
      checkpoint_dynamic = {
          'epoch': epoch + 1,
          'state_dict': net_dynamic.state_dict(),
          'optimizer': optimizerD.state_dict(),
      }

      checkpoint_inpainter = {
          'epoch': epoch + 1,
          'state_dict': net_impainter.state_dict(),
          'optimizer': optimizerI.state_dict(),
      }

      save_ckp(checkpoint_dynamic, checkpoint_dynamic_path+"checkpoint_"+str(epoch+1)+".pt")
      save_ckp(checkpoint_inpainter, checkpoint_inpainter_path+"checkpoint_"+str(epoch+1)+".pt")
 

    step+=1
    
    # break
  checkpoint_dynamic = {
      'epoch': epoch + 1,
      'state_dict': net_dynamic.state_dict(),
      'optimizer': optimizerD.state_dict(),
  }

  checkpoint_inpainter = {
      'epoch': epoch + 1,
      'state_dict': net_impainter.state_dict(),