# Example 1
def train_model(model,
                criterion1,
                criterion2,
                criterion3,
                optimizer,
                scheduler=None,
                save_path=None,
                num_epochs=25,
                iter_size=1,
                compare='loss'):
    """Train `model` with up to three summed losses, keeping the best weights.

    Iterates over the module-level `dataloaders['train'|'valid']`. Each
    criterion may be None, in which case it contributes 0 to the total loss.
    Gradients are accumulated over `iter_size` mini-batches before each
    optimizer step. The validation metric chosen by `compare` ('dice1',
    'dice3', anything else -> negative loss) selects the best epoch; its
    weights are deep-copied, optionally saved to `save_path` (a format
    string receiving epoch, lr, metric), and restored before returning.

    Returns the model loaded with the best validation weights.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_val = -sys.maxsize  # 'higher is better'; loss is negated to fit
    monitor = MetricMonitor()

    for epoch in range(num_epochs):
        # Each epoch has a training and validation phase
        for phase in ['train', 'valid']:
            model.train(
                phase == 'train')  # Set model to training/evaluate mode
            optimizer.zero_grad()
            monitor.reset()
            stream = tqdm(dataloaders[phase], file=sys.stdout)
            # Iterate over data.
            for i, samples in enumerate(stream, start=1):
                # get the inputs
                # NOTE: `async=True` is a SyntaxError on Python >= 3.7;
                # `non_blocking=True` is the supported equivalent.
                inputs = torch.tensor(samples['image'],
                                      requires_grad=True).cuda(non_blocking=True)
                # get the targets
                vectors = torch.tensor(samples['vectors'],
                                       dtype=torch.float).cuda(non_blocking=True)
                masks = torch.tensor(samples['masks'],
                                     dtype=torch.float).cuda(non_blocking=True)
                areas = torch.tensor(samples['areas'],
                                     dtype=torch.float).cuda(non_blocking=True)

                # forward; criteria left as None contribute nothing
                outputs1, outputs2, outputs3 = model(inputs)
                loss1 = criterion1(inputs, outputs1, vectors, masks,
                                   areas) if criterion1 is not None else 0
                loss2 = criterion2(inputs, outputs2, vectors, masks,
                                   areas) if criterion2 is not None else 0
                loss3 = criterion3(inputs, outputs3, vectors, masks,
                                   areas) if criterion3 is not None else 0
                loss = loss1 + loss2 + loss3

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    # step every `iter_size` batches (gradient accumulation),
                    # and always on the final batch so no gradients are lost
                    if i % iter_size == 0 or i == len(dataloaders[phase]):
                        optimizer.step()
                        optimizer.zero_grad()

                # statistics
                dice1 = dice_value(outputs1.data,
                                   torch.unsqueeze(masks[:, 0], 1).data)
                dice3 = dice_value(outputs3.data, masks.data)
                monitor.update('loss', loss.data, inputs.shape[0])
                monitor.update('dice1', dice1.data, inputs.shape[0])
                monitor.update('dice3', dice3.data, inputs.shape[0])
                stream.set_description(f'epoch {epoch+1}/{num_epochs} | '
                                       f'{phase}: {monitor}')
            stream.close()

            # metric to maximize; loss is negated so 'greater is better'
            if compare == 'dice1':
                epoch_val = monitor.get_avg('dice1')
            elif compare == 'dice3':
                epoch_val = monitor.get_avg('dice3')
            else:
                epoch_val = -monitor.get_avg('loss')

            if phase == 'valid' and scheduler is not None:
                # scheduler (e.g. ReduceLROnPlateau in 'min' mode) tracks
                # the negated metric
                scheduler.step(-epoch_val)

            # deep copy the model
            if (phase == 'valid') and (epoch_val > best_val):
                best_val = epoch_val
                best_model_wts = copy.deepcopy(model.state_dict())
                if save_path is not None:
                    path = save_path.format((epoch + 1),
                                            optimizer.param_groups[0]['lr'],
                                            abs(best_val))
                    torch.save(best_model_wts, path)
                    print('Weights of model saved at {}'.format(path))

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Val: {:.4f}'.format(abs(best_val)))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
# Example 2
def train_model(model,
                criterion,
                optimizer,
                scheduler=None,
                save_path=None,
                num_epochs=25,
                iter_size=1):
    """Train `model` with a single criterion, tracking the best Dice score.

    Iterates over the module-level `dataloaders['train'|'valid']`.
    Gradients are accumulated over `iter_size` mini-batches before each
    optimizer step. The weights achieving the best validation Dice are
    deep-copied, optionally saved to `save_path` (a format string receiving
    the Dice value), and restored before returning.

    Returns the model loaded with the best validation weights.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_dice = 0
    monitor = MetricMonitor()

    for epoch in range(num_epochs):
        # Each epoch has a training and validation phase
        for phase in ['train', 'valid']:
            model.train(
                phase == 'train')  # Set model to training/evaluate mode
            optimizer.zero_grad()
            monitor.reset()
            stream = tqdm(dataloaders[phase], file=sys.stdout)
            # Iterate over data.
            for i, samples in enumerate(stream, start=1):
                # get the inputs
                # NOTE: `async=True` is a SyntaxError on Python >= 3.7;
                # `non_blocking=True` is the supported equivalent.
                inputs = torch.tensor(samples['image'],
                                      requires_grad=True).cuda(non_blocking=True)
                # get the targets
                targets = torch.tensor(samples['masks'],
                                       dtype=torch.long).cuda(non_blocking=True)

                # forward
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    # step every `iter_size` batches (gradient accumulation),
                    # and always on the final batch so no gradients are lost
                    if i % iter_size == 0 or i == len(dataloaders[phase]):
                        optimizer.step()
                        optimizer.zero_grad()

                # statistics
                dice = dice_value(outputs.data, targets.data, None)
                monitor.update('loss', loss.data, inputs.shape[0])
                monitor.update('dice', dice.data, inputs.shape[0])
                stream.set_description(f'epoch {epoch+1}/{num_epochs} | '
                                       f'{phase}: {monitor}')
            stream.close()

            epoch_loss = monitor.get_avg('loss')
            epoch_dice = monitor.get_avg('dice')

            if phase == 'valid' and scheduler is not None:
                # negated so a 'min'-mode plateau scheduler maximizes Dice
                scheduler.step(-epoch_dice)

            # deep copy the model
            if (phase == 'valid') and (epoch_dice > best_dice):
                best_dice = epoch_dice
                best_model_wts = copy.deepcopy(model.state_dict())
                if save_path is not None:
                    path = save_path.format(best_dice)
                    torch.save(best_model_wts, path)
                    print('Weights of model saved at {}'.format(path))

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Dice: {:.4f}'.format(best_dice))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
# Example 3
def train_model(model,
                criterion,
                optimizer,
                scheduler=None,
                model_save_path=None,
                optim_save_path=None,
                log_save_path=None,
                num_epochs=25,
                iter_size=1,
                compare_Loss=False):
    """Train `model` on masks+centroids targets, with optional file logging.

    Iterates over the module-level `dataloaders['train'|'valid']`.
    Gradients are accumulated over `iter_size` mini-batches before each
    optimizer step. The validation metric is Dice by default, or negative
    loss when `compare_Loss` is True; the best weights are deep-copied,
    optionally saved to `model_save_path` (format string: epoch, metric),
    and restored before returning. The optimizer state is saved every
    validation phase to `optim_save_path` (format string: epoch, lr) when
    given. Progress is appended to `log_save_path` when given.

    Returns the model loaded with the best validation weights.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_val = -sys.maxsize  # 'higher is better'; loss is negated to fit
    monitor = MetricMonitor()

    class _NullLog:
        """No-op stand-in for the log file when no path is given."""

        def write(self, text):
            pass

        def flush(self):
            pass

        def close(self):
            pass

    log = open(log_save_path, 'a') if log_save_path is not None else _NullLog()
    log.write(f'Training start at {time.strftime("%Y-%m-%d %H:%M")}\n\n')

    for epoch in range(num_epochs):
        # Each epoch has a training and validation phase
        for phase in ['train', 'valid']:
            model.train(
                phase == 'train')  # Set model to training/evaluate mode
            optimizer.zero_grad()
            monitor.reset()
            stream = tqdm(dataloaders[phase], file=sys.stdout)
            # Iterate over data.
            for i, samples in enumerate(stream, start=1):
                # get the inputs
                # NOTE: `async=True` is a SyntaxError on Python >= 3.7;
                # `non_blocking=True` is the supported equivalent.
                inputs = torch.tensor(samples['image'],
                                      requires_grad=True).cuda(non_blocking=True)
                # get the targets; class labels are masks offset by centroids
                masks = torch.tensor(samples['masks'],
                                     dtype=torch.long).cuda(non_blocking=True)
                centroids = torch.tensor(samples['centroids'],
                                         dtype=torch.long).cuda(non_blocking=True)
                targets = masks + centroids

                # forward
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    # step every `iter_size` batches (gradient accumulation),
                    # and always on the final batch so no gradients are lost
                    if i % iter_size == 0 or i == len(dataloaders[phase]):
                        optimizer.step()
                        optimizer.zero_grad()

                # statistics
                dice = ce_dice_value(outputs.data, targets.data, [0, 0, 1, 0])
                monitor.update('loss', loss.data, inputs.shape[0])
                monitor.update('dice', dice.data, inputs.shape[0])
                stream.set_description(
                    f'epoch {epoch+1}/{num_epochs} | {phase}: {monitor}')
            stream.close()

            epoch_loss = monitor.get_avg('loss')
            epoch_dice = monitor.get_avg('dice')
            # metric to maximize; loss is negated so 'greater is better'
            epoch_val = epoch_dice if not compare_Loss else -epoch_loss

            log.write(
                f'epoch {epoch+1}/{num_epochs} | {phase}: {monitor} | lr {optimizer.param_groups[0]["lr"]:.0e}\n'
            )

            if phase == 'valid' and scheduler is not None:
                # negated so a 'min'-mode plateau scheduler maximizes the metric
                scheduler.step(-epoch_val)

            # save the model and optimizer
            if (phase == 'valid') and (epoch_val > best_val):
                best_val = epoch_val
                best_model_wts = copy.deepcopy(model.state_dict())
                if model_save_path is not None:
                    path = model_save_path.format((epoch + 1), abs(best_val))
                    torch.save(best_model_wts, path)
                    print(f'Weights of model saved at {path}')
                    log.write(f'Weights of model saved at {path}\n')
            if (phase == 'valid') and (optim_save_path is not None):
                path = optim_save_path.format((epoch + 1),
                                              optimizer.param_groups[0]['lr'])
                torch.save(optimizer.state_dict(), path)
            log.flush()

        log.write('\n')
        print()

    time_elapsed = time.time() - since
    print(
        f'Training complete in {(time_elapsed//60):.0f}m {(time_elapsed%60):.0f}s'
    )
    print(f'Best Val: {abs(best_val):.4f}')
    log.write(
        f'Training complete in {(time_elapsed//60):.0f}m {(time_elapsed%60):.0f}s\n'
    )
    log.write(f'Best Val: {abs(best_val):f}\n')
    log.close()

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model