Code example #1
 def __init__(self, num_classes, optimizer, lr_args, optimizer_args):
     self.lr_scheduler_args = lr_args
     self.optimizer_args = optimizer_args
     self.optimizer = optimizer
     self.loss_func = CombinedLoss(weight_dice=0, weight_ce=100)
     self.num_classes = num_classes
     self.classes = list(range(self.num_classes))
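The CombinedLoss constructed above takes per-term weights as keyword arguments; its implementation is not part of this snippet. A minimal sketch of what such a class might look like, assuming it simply forms the weighted sum weight_dice * dice + weight_ce * cross_entropy (the soft-Dice helper below is likewise an assumption, not code from this repository):

import torch
import torch.nn as nn

class CombinedLoss(nn.Module):
    """Hypothetical weighted sum of a soft Dice loss and cross-entropy."""

    def __init__(self, weight_dice=1.0, weight_ce=1.0):
        super().__init__()
        self.weight_dice = weight_dice
        self.weight_ce = weight_ce
        self.ce = nn.CrossEntropyLoss()

    def _dice(self, logits, target):
        # Soft Dice between class probabilities and a one-hot target
        probs = torch.softmax(logits, dim=1)
        one_hot = torch.zeros_like(probs).scatter_(1, target.unsqueeze(1), 1.0)
        intersection = (probs * one_hot).sum()
        return 1.0 - 2.0 * intersection / (probs.sum() + one_hot.sum() + 1e-7)

    def forward(self, logits, target):
        return (self.weight_dice * self._dice(logits, target)
                + self.weight_ce * self.ce(logits, target))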
Code example #2
    def update_model(self, round_num, caption):
        # Strip the '<start>' and '<end>' tokens if they wrap the caption
        if caption[:7] == '<start>':
            caption = caption[8:-6]

        # Don't update on an empty caption
        if len(caption.split()) < 1:
            return

        combined_loss = CombinedLoss(self)
        data_loader = get_reduction_loader(
            self.raw_image, self.vocab, self.batch_size, caption, self.dataset_type,
            shuffle=True, num_workers=self.num_workers
        )
        
        # define optimizer
        params = list(self.decoder.parameters())
        optimizer = torch.optim.Adam(params, lr=self.learning_rate)

        # Keep training until we hit the specified number of gradient steps
        steps = 0
        while steps < self.num_steps:
            for batch in data_loader:
                if steps == self.num_steps:
                    break
                loss = combined_loss.compute(batch, steps)
                self.decoder.zero_grad()
                loss.backward()
                optimizer.step()
                steps += 1

        # After adaptation, add current trial's data to 'memory' for future rounds
        self.history.append({'target': self.raw_image, 'cap': caption})

        # precompute history captions so we don't have to do it again on every step
        for reduced_cap in build_dataset(caption, self.dataset_type):
            self.orig_captions.append((self.raw_image, reduced_cap))

        # Save the model checkpoint
        if self.checkpoint:
            ckpt_loc = 'decoder-{}.ckpt'.format(self.gameid)
            torch.save(self.decoder.state_dict(),
                       os.path.join(self.model_path, ckpt_loc))
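The loop above restarts the DataLoader as many times as needed until exactly self.num_steps gradient steps have run. The same pattern can be factored into a reusable generator; a sketch (the helper name is mine, not from this codebase):

def limited_steps(loader, num_steps):
    # Yield batches from loader, restarting it as needed,
    # until exactly num_steps batches have been produced.
    steps = 0
    while steps < num_steps:
        for batch in loader:
            yield batch
            steps += 1
            if steps == num_steps:
                return

# usage: for step, batch in enumerate(limited_steps(data_loader, self.num_steps)): ...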
Code example #3
def main(argv):
    params = args_parsing(cmd_args_parsing(argv))
    root, experiment_name, image_size, batch_size, lr, n_epochs, log_dir, checkpoint_path = (
        params['root'], params['experiment_name'], params['image_size'],
        params['batch_size'], params['lr'], params['n_epochs'],
        params['log_dir'], params['checkpoint_path'])

    train_val_split(os.path.join(root, DATASET_TABLE_PATH))
    dataset = pd.read_csv(os.path.join(root, DATASET_TABLE_PATH))

    pre_transforms = torchvision.transforms.Compose(
        [Resize(size=image_size), ToTensor()])
    batch_transforms = torchvision.transforms.Compose(
        [BatchEncodeSegmentaionMap()])
    augmentation_batch_transforms = torchvision.transforms.Compose([
        BatchToPILImage(),
        BatchHorizontalFlip(p=0.5),
        BatchRandomRotation(degrees=10),
        BatchRandomScale(scale=(1.0, 2.0)),
        BatchBrightContrastJitter(brightness=(0.5, 2.0), contrast=(0.5, 2.0)),
        BatchToTensor(),
        BatchEncodeSegmentaionMap()
    ])

    train_dataset = SegmentationDataset(
        dataset=dataset[dataset['phase'] == 'train'], transform=pre_transforms)

    train_sampler = SequentialSampler(train_dataset)
    train_batch_sampler = BatchSampler(train_sampler, batch_size)
    train_collate = collate_transform(augmentation_batch_transforms)
    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_sampler=train_batch_sampler,
        collate_fn=train_collate)

    val_dataset = SegmentationDataset(
        dataset=dataset[dataset['phase'] == 'val'], transform=pre_transforms)

    val_sampler = SequentialSampler(val_dataset)
    val_batch_sampler = BatchSampler(val_sampler, batch_size)
    val_collate = collate_transform(batch_transforms)
    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_sampler=val_batch_sampler,
        collate_fn=val_collate)

    # model = Unet_with_attention(1, 2, image_size[0], image_size[1]).to(device)
    # model = UNet(1, 2).to(device)
    # model = UNetTC(1, 2).to(device)

    model = UNetFourier(1, 2, image_size, fourier_layer='linear').to(device)

    writer, experiment_name, best_model_path = setup_experiment(
        model.__class__.__name__, log_dir, experiment_name)

    new_checkpoint_path = os.path.join(root, 'checkpoints',
                                       experiment_name + '_latest.pth')
    best_checkpoint_path = os.path.join(root, 'checkpoints',
                                        experiment_name + '_best.pth')
    os.makedirs(os.path.dirname(new_checkpoint_path), exist_ok=True)

    if checkpoint_path is not None:
        checkpoint_path = os.path.join(root, 'checkpoints', checkpoint_path)
        print(f"\nLoading checkpoint from {checkpoint_path}.\n")
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = None
    best_model_path = os.path.join(root, best_model_path)
    print(f"Experiment name: {experiment_name}")
    print(f"Model has {count_parameters(model):,} trainable parameters")
    print()

    criterion = CombinedLoss(
        [CrossEntropyLoss(),
         GeneralizedDiceLoss(weighted=True)], [0.4, 0.6])
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           factor=0.5,
                                                           patience=5)
    metric = DiceMetric()
    weighted_metric = DiceMetric(weighted=True)

    print(
        "To monitor the learning process, run this command in a new terminal:\ntensorboard --logdir <path to log directory>"
    )
    print()
    train(model, train_dataloader, val_dataloader, criterion, optimizer,
          scheduler, metric, weighted_metric, n_epochs, device, writer,
          best_model_path, best_checkpoint_path, checkpoint,
          new_checkpoint_path)
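In this example CombinedLoss has a different interface from example #1: it takes a list of loss modules and a matching list of weights. A minimal sketch of a wrapper with that call signature, assuming it just sums the weighted terms:

import torch.nn as nn

class CombinedLoss(nn.Module):
    # Hypothetical weighted sum over an arbitrary list of loss modules
    def __init__(self, losses, weights):
        super().__init__()
        assert len(losses) == len(weights)
        self.losses = nn.ModuleList(losses)
        self.weights = weights

    def forward(self, prediction, target):
        return sum(w * loss(prediction, target)
                   for w, loss in zip(self.weights, self.losses))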
Code example #4
    print(
        f'Number of total params: {sum([np.prod(p.shape) for p in model.parameters()])}'
    )

    if start_epoch >= args.num_epochs:
        print(
            'The model has already been trained for the number of epochs required'
        )

    if not os.path.exists(args.log):
        os.makedirs(args.log)
    runs = sorted(os.listdir(args.log))
    index = 0 if not runs else int(runs[-1][:4]) + 1
    args.log = os.path.join(args.log, '%.4d-train' % index)
    os.makedirs(args.log)
    print(f'==> Saving logs to {args.log}')

    solver = Solver(model=model,
                    optimizer=optimizer,
                    criterion=CombinedLoss(),
                    start_epoch=start_epoch,
                    num_epochs=args.num_epochs,
                    device=device,
                    log_dir=args.log,
                    checkpoint_interval=args.checkpoint_interval,
                    amp=amp if args.amp else None)

    print(f'==> Training for {args.num_epochs} epochs...')
    solver.train(train_dataloader, test_dataloader)
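The amp object handed to Solver here is NVIDIA's apex mixed-precision module (example #5 below loads its state with amp.load_state_dict). In current PyTorch the same role is usually played by torch.cuda.amp; a sketch of that modern pattern, shown only for comparison (this is not what the code above uses):

import torch

scaler = torch.cuda.amp.GradScaler()

def training_step(model, inputs, targets, criterion, optimizer):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():      # run the forward pass in mixed precision
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()        # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)
    scaler.update()                      # adjust the scale factor for the next step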
Code example #5
        checkpoint = torch.load(
            # Load the most recent checkpoint; this map_location keeps tensors on the CPU
            sorted(glob.glob(os.path.join(args.log, args.model, 'checkpoint*.pkl'))).pop(),
            map_location=lambda storage, location: storage
        )
        print(f'Model trained for {checkpoint["epoch"] + 1} epoch(s)...')
        start_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if args.amp:
            amp.load_state_dict(checkpoint['amp'])
    
    else:
        raise FileNotFoundError(f'Checkpoint not found at {args.model}!')  
        
    print(f'Number of total params: {sum([np.prod(p.shape) for p in model.parameters()])}')

    if not os.path.exists(args.log):
        os.makedirs(args.log)
    runs = sorted(os.listdir(args.log), key=lambda x: x[:4])
    # Guard against an empty log directory (a bare .pop() would raise IndexError)
    index = 0 if not runs else int(runs[-1][:4]) + 1
    args.log = os.path.join(args.log, '%.4d-val' % index)
    os.makedirs(args.log)

    print(f'==> Saving logs to {args.log}')

    solver = Solver(
        model=model, optimizer=optimizer, criterion=CombinedLoss(),
        start_epoch=None, num_epochs=None, device=device,
        log_dir=args.log, checkpoint_interval=None, amp=amp if args.amp else None
    )
    print('==> Validating...')
    solver.validate(test_dataloader)
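For reference, a checkpoint carrying the keys this snippet reads back ('epoch', 'model_state_dict', 'optimizer_state_dict', and optionally 'amp') could be written as follows; a minimal sketch, since the actual saving code is not shown here:

import torch

def save_checkpoint(path, epoch, model, optimizer, amp=None):
    # Keys mirror those consumed by the loading code above
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    if amp is not None:
        state['amp'] = amp.state_dict()
    torch.save(state, path)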
    
Code example #6
def run_epoch(model,
              iterator,
              optimizer,
              metric,
              weighted_metric=None,
              phase='train',
              epoch=0,
              device='cpu',
              writer=None):
    is_train = (phase == 'train')
    if is_train:
        model.train()
    else:
        model.eval()

    criterion_bce = torch.nn.BCELoss()
    criterion_dice = DiceLoss()

    epoch_loss = 0.0
    epoch_metric = 0.0
    if weighted_metric is not None:
        epoch_weighted_metric = 0.0

    with torch.set_grad_enabled(is_train):
        batch_to_plot = np.random.choice(range(len(iterator)))
        for i, (images, masks) in enumerate(tqdm(iterator)):
            images, masks = images.to(device), masks.to(device)

            if is_train:
                # Deep supervision: the model returns eight side outputs plus
                # the final output
                (outputs1, outputs2, outputs3, outputs4,
                 outputs1_1, outputs1_2, outputs1_3, outputs1_4,
                 output) = model(images)

                # Keep the raw final output for the metric computation below
                predicted_masks = output

                side_outputs = [
                    torch.sigmoid(o) for o in
                    (output, outputs1, outputs2, outputs3, outputs4,
                     outputs1_1, outputs1_2, outputs1_3, outputs1_4)
                ]

                label = masks.to(torch.float)

                # Per-output weights for the BCE and Dice terms (note the two
                # sets differ for the last three side outputs)
                bce_weights = (1.0, 0.4, 0.5, 0.7, 0.8, 0.4, 0.5, 0.7, 0.8)
                dice_weights = (1.0, 0.4, 0.5, 0.7, 0.8, 0.4, 0.7, 0.8, 1.0)

                loss = sum(w * criterion_bce(o, label)
                           for w, o in zip(bce_weights, side_outputs))
                loss = loss + sum(w * criterion_dice(o, label)
                                  for w, o in zip(dice_weights, side_outputs))

            else:
                predict = model(images)
                predicted_masks = torch.sigmoid(predict).cpu().numpy()

                # Build a two-channel (background, foreground) probability map
                predicted_masks_0 = 1 - predicted_masks
                predicted_masks_1 = predicted_masks
                predicted_masks = np.concatenate(
                    [predicted_masks_0, predicted_masks_1], axis=1)

                criterion = CombinedLoss(
                    [CrossEntropyLoss(),
                     GeneralizedDiceLoss(weighted=True)], [0.4, 0.6])
                predicted_masks = torch.tensor(predicted_masks).to(device)
                loss = criterion(predicted_masks.to(torch.float), masks)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            epoch_metric += metric(torch.argmax(predicted_masks, dim=1), masks)
            if weighted_metric is not None:
                epoch_weighted_metric += weighted_metric(
                    torch.argmax(predicted_masks, dim=1), masks)

            if i == batch_to_plot:
                images_to_plot, masks_to_plot, predicted_masks_to_plot = process_to_plot(
                    images, masks, predicted_masks)

        if writer is not None:
            writer.add_scalar(f"loss_epoch/{phase}",
                              epoch_loss / len(iterator), epoch)
            writer.add_scalar(f"metric_epoch/{phase}",
                              epoch_metric / len(iterator), epoch)
            if weighted_metric is not None:
                writer.add_scalar(f"weighted_metric_epoch/{phase}",
                                  epoch_weighted_metric / len(iterator), epoch)

            # Show images from the randomly chosen batch by sending them to TensorBoard
            writer.add_images(tag='images',
                              img_tensor=images_to_plot,
                              global_step=epoch + 1)
            writer.add_images(tag='true masks',
                              img_tensor=masks_to_plot,
                              global_step=epoch + 1)
            writer.add_images(tag='predicted masks',
                              img_tensor=predicted_masks_to_plot,
                              global_step=epoch + 1)

        if weighted_metric is not None:
            return epoch_loss / len(iterator), epoch_metric / len(
                iterator), epoch_weighted_metric / len(iterator)
        return epoch_loss / len(iterator), epoch_metric / len(iterator), None
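One detail worth noting in run_epoch: BCELoss expects probabilities, which is why every side output is passed through torch.sigmoid first. The numerically safer alternative (not used above) is BCEWithLogitsLoss, which fuses the sigmoid into the loss and takes raw logits:

import torch

criterion = torch.nn.BCEWithLogitsLoss()   # sigmoid + BCE in one numerically stable op
logits = torch.randn(2, 1, 4, 4)           # raw model output, no sigmoid applied
target = torch.empty(2, 1, 4, 4).uniform_(0, 1).round()
loss = criterion(logits, target)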
Code example #7
def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('-o', '--output_dir', default=None, help='output dir')
    parser.add_argument('-b',
                        '--batch_size',
                        type=int,
                        default=1,
                        metavar='N',
                        help='input batch size for training')
    parser.add_argument('--epochs',
                        type=int,
                        default=400,
                        help='number of epochs to train')
    parser.add_argument('-lr',
                        '--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate')
    parser.add_argument(
        '-reset_lr',
        '--reset_lr',
        action='store_true',
        help='reset LR cycles; if not set, count epochs from 0')
    parser.add_argument('-opt',
                        '--optimizer',
                        default='sgd',
                        choices=['sgd', 'adam', 'rmsprop'],
                        help='optimizer type')
    parser.add_argument('--decay_step',
                        type=float,
                        default=100,
                        metavar='EPOCHS',
                        help='learning rate decay step')
    parser.add_argument('--decay_gamma',
                        type=float,
                        default=0.5,
                        help='learning rate decay coefficient')
    parser.add_argument(
        '--cyclic_lr',
        type=int,
        default=None,
        help='length (in epochs) of the cycle; if not None, use a cyclic '
             'LR with the specified cycle length')
    parser.add_argument(
        '--cyclic_duration',
        type=float,
        default=1.0,
        help='multiplier of the duration of segments in the cycle')

    parser.add_argument('--weight_decay',
                        type=float,
                        default=0.0005,
                        help='L2 regularizer weight')
    parser.add_argument('--seed', type=int, default=1993, help='random seed')
    parser.add_argument(
        '--log_aggr',
        type=int,
        default=None,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('-gacc',
                        '--num_grad_acc_steps',
                        type=int,
                        default=1,
                        metavar='N',
                        help='number of batches to accumulate gradients over')
    parser.add_argument(
        '-imsize',
        '--image_size',
        type=int,
        default=1024,
        metavar='N',
        help='input image size; -1 and -2 select the full-size presets')
    parser.add_argument('-f',
                        '--fold',
                        type=int,
                        default=0,
                        metavar='N',
                        help='fold_id')
    parser.add_argument('-nf',
                        '--n_folds',
                        type=int,
                        default=0,
                        metavar='N',
                        help='number of folds')
    parser.add_argument(
        '-fv',
        '--folds_version',
        type=int,
        default=1,
        choices=[1, 2],
        help='version of folds (1) - random, (2) - stratified on mask area')
    parser.add_argument('-group',
                        '--group',
                        type=parse_group,
                        default='all',
                        help='group id')
    parser.add_argument('-no_cudnn',
                        '--no_cudnn',
                        action='store_true',
                        help='disable cuDNN benchmark mode?')
    parser.add_argument('-aug',
                        '--aug',
                        type=int,
                        default=None,
                        help='use augmentations?')
    parser.add_argument('-no_hq',
                        '--no_hq',
                        action='store_true',
                        help='do not use hq images?')
    parser.add_argument('-dbg', '--dbg', action='store_true', help='is debug?')
    parser.add_argument('-is_log_dice',
                        '--is_log_dice',
                        action='store_true',
                        help='use -log(dice) in loss?')
    parser.add_argument('-no_weight_loss',
                        '--no_weight_loss',
                        action='store_true',
                        help='do not weight border in loss?')

    parser.add_argument('-suf',
                        '--exp_suffix',
                        default='',
                        help='experiment suffix')
    parser.add_argument('-net', '--network', default='Unet')

    args = parser.parse_args()
    print('aug:', args.aug)
    # assert args.aug, 'Careful! No aug specified!'
    if args.log_aggr is None:
        args.log_aggr = 1
    print('log_aggr', args.log_aggr)

    random.seed(42)
    torch.manual_seed(args.seed)
    print('CudNN:', torch.backends.cudnn.version())
    print('Run on {} GPUs'.format(torch.cuda.device_count()))
    torch.backends.cudnn.benchmark = not args.no_cudnn  # enable the cuDNN autotuner unless disabled

    experiment = "{}_s{}_im{}_gacc{}{}{}{}_{}fold{}.{}".format(
        args.network, args.seed, args.image_size, args.num_grad_acc_steps,
        '_aug{}'.format(args.aug) if args.aug is not None else '',
        '_nohq' if args.no_hq else '',
        '_g{}'.format(args.group) if args.group != 'all' else '',
        'v2' if args.folds_version == 2 else '', args.fold, args.n_folds)
    if args.output_dir is None:
        ckpt_dir = join(config.models_dir, experiment + args.exp_suffix)
        if os.path.exists(join(ckpt_dir, 'checkpoint.pth.tar')):
            args.output_dir = ckpt_dir
    if args.output_dir is not None and os.path.exists(args.output_dir):
        ckpt_path = join(args.output_dir, 'checkpoint.pth.tar')
        if not os.path.isfile(ckpt_path):
            print("=> no checkpoint found at '{}'\nUsing model_best.pth.tar".format(
                ckpt_path))
            ckpt_path = join(args.output_dir, 'model_best.pth.tar')
        if os.path.isfile(ckpt_path):
            print("=> loading checkpoint '{}'".format(ckpt_path))
            checkpoint = torch.load(ckpt_path)
            if 'filter_sizes' in checkpoint:
                filters_sizes = checkpoint['filter_sizes']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                ckpt_path, checkpoint['epoch']))
        else:
            raise IOError("=> no checkpoint found at '{}'".format(ckpt_path))
    else:
        checkpoint = None
        if args.network == 'Unet':
            filters_sizes = np.asarray([32, 64, 128, 256, 512, 1024, 1024])
        elif args.network == 'UNarrow':
            filters_sizes = np.asarray([32, 32, 64, 128, 256, 512, 768])
        elif args.network == 'Unet7':
            filters_sizes = np.asarray(
                [48, 96, 128, 256, 512, 1024, 1536, 1536])
        elif args.network == 'Unet5':
            filters_sizes = np.asarray([32, 64, 128, 256, 512, 1024])
        elif args.network == 'Unet4':
            filters_sizes = np.asarray([24, 64, 128, 256, 512])
        elif args.network in ['vgg11v1', 'vgg11v2']:
            filters_sizes = np.asarray([64])
        elif args.network in ['vgg11av1', 'vgg11av2']:
            filters_sizes = np.asarray([32])
        else:
            raise ValueError('Unknown Net: {}'.format(args.network))
    if args.network in ['vgg11v1', 'vgg11v2']:
        assert args.network[-2] == 'v'
        v = int(args.network[-1:])
        model = torch.nn.DataParallel(
            UnetVgg11(n_classes=1, num_filters=filters_sizes.item(),
                      v=v)).cuda()
    elif args.network in ['vgg11av1', 'vgg11av2']:
        assert args.network[-2] == 'v'
        v = int(args.network[-1:])
        model = torch.nn.DataParallel(
            vgg_unet.Vgg11a(n_classes=1, num_filters=filters_sizes.item(),
                            v=v)).cuda()
    else:
        unet_class = getattr(unet, args.network)
        model = torch.nn.DataParallel(
            unet_class(is_deconv=False, filters=filters_sizes)).cuda()

    print('  + Number of params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    rescale_size = (args.image_size, args.image_size)
    is_full_size = False
    if args.image_size == -1:
        print('Using full size with padding')
        is_full_size = True
        rescale_size = (1920, 1280)
    elif args.image_size == -2:
        rescale_size = (1856, 1248)

    train_dataset = CarvanaPlus(
        root=config.input_data_dir,
        subset='train',
        image_size=args.image_size,
        transform=TrainTransform(
            rescale_size,
            aug=args.aug,
            resize_mask=True,
            should_pad=is_full_size,
            should_normalize=args.network.startswith('vgg')),
        seed=args.seed,
        is_hq=not args.no_hq,
        fold_id=args.fold,
        n_folds=args.n_folds,
        group=args.group,
        return_image_id=True,
        v=args.folds_version)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=4 if torch.cuda.device_count() > 1 else 1)

    val_dataset = CARVANA(
        root=config.input_data_dir,
        subset='val',
        image_size=args.image_size,
        transform=TrainTransform(
            rescale_size,
            aug=None,
            resize_mask=False,
            should_pad=is_full_size,
            should_normalize=args.network.startswith('vgg')),
        seed=args.seed,
        is_hq=not args.no_hq,
        fold_id=args.fold,
        n_folds=args.n_folds,
        group=args.group,
        v=args.folds_version,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch_size * 2,
        shuffle=False,
        pin_memory=True,
        num_workers=4
        if torch.cuda.device_count() > 4 else torch.cuda.device_count())

    print('Weight loss:', not args.no_weight_loss)
    print('-log(dice) in loss:', args.is_log_dice)
    criterion = CombinedLoss(is_weight=not args.no_weight_loss,
                             is_log_dice=args.is_log_dice).cuda()

    if args.optimizer == 'adam':
        print('Using Adam optimizer!')
        optimizer = optim.Adam(model.parameters(),
                               weight_decay=args.weight_decay,
                               lr=args.lr)
    elif args.optimizer == 'rmsprop':
        optimizer = optim.RMSprop(
            model.parameters(), lr=args.lr,
            weight_decay=args.weight_decay)  # For Tiramisu weight_decay=0.0001
    else:
        optimizer = optim.SGD(model.parameters(),
                              weight_decay=args.weight_decay,
                              lr=args.lr,
                              momentum=0.9,
                              nesterov=False)

    if args.output_dir is not None:
        out_dir = args.output_dir
    else:
        out_dir = join(config.models_dir, experiment + args.exp_suffix)
    print('Model dir:', out_dir)
    if args.dbg:
        out_dir = 'dbg_runs'
    logger = SummaryWriter(log_dir=out_dir)

    if checkpoint is not None:
        start_epoch = checkpoint['epoch']
        best_score = checkpoint['best_score']
        print('Best score:', best_score)
        print('Current score:', checkpoint['cur_score'])
        model.load_state_dict(checkpoint['state_dict'])
        print('state dict loaded')
        optimizer.load_state_dict(checkpoint['optimizer'])
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr
            param_group['initial_lr'] = args.lr
        # validate(val_loader, model, start_epoch * len(train_loader), logger,
        #   is_eval=args.batch_size > 1, is_full_size=is_full_size)
    else:
        start_epoch = 0
        best_score = 0
        # validate(val_loader, model, start_epoch * len(train_loader), logger,
        #          is_eval=args.batch_size > 1, is_full_size=is_full_size)

    if args.cyclic_lr is None:
        scheduler = StepLR(optimizer,
                           step_size=args.decay_step,
                           gamma=args.decay_gamma)
        print('scheduler.base_lrs =', scheduler.base_lrs)
    elif args.network.startswith('vgg'):
        print('Using VggCyclic LR!')
        cyclic_lr = VggCyclicLr(start_epoch if args.reset_lr else 0,
                                init_lr=args.lr,
                                num_epochs_per_cycle=args.cyclic_lr,
                                duration=args.cyclic_duration)
        scheduler = LambdaLR(optimizer, lr_lambda=cyclic_lr)
        scheduler.base_lrs = list(
            map(lambda group: 1.0, optimizer.param_groups))
    else:
        print('Using Cyclic LR!')
        cyclic_lr = CyclicLr(start_epoch if args.reset_lr else 0,
                             init_lr=args.lr,
                             num_epochs_per_cycle=args.cyclic_lr,
                             epochs_pro_decay=args.decay_step,
                             lr_decay_factor=args.decay_gamma)
        scheduler = LambdaLR(optimizer, lr_lambda=cyclic_lr)
        scheduler.base_lrs = list(
            map(lambda group: 1.0, optimizer.param_groups))

    logger.add_scalar('data/batch_size', args.batch_size, start_epoch)
    logger.add_scalar('data/num_grad_acc_steps', args.num_grad_acc_steps,
                      start_epoch)
    logger.add_text('config/info', 'filters sizes: {}'.format(filters_sizes))

    last_lr = 100500  # sentinel: any real LR differs, so the first epoch always logs

    for epoch in range(start_epoch, args.epochs):
        # train for one epoch
        scheduler.step(epoch=epoch)
        if last_lr != scheduler.get_lr()[0]:
            last_lr = scheduler.get_lr()[0]
            print('LR := {}'.format(last_lr))
        logger.add_scalar('data/lr', scheduler.get_lr()[0], epoch)
        logger.add_scalar('data/aug', args.aug if args.aug is not None else -1,
                          epoch)
        logger.add_scalar('data/weight_decay', args.weight_decay, epoch)
        logger.add_scalar('data/is_weight_loss', not args.no_weight_loss,
                          epoch)
        logger.add_scalar('data/is_log_dice', args.is_log_dice, epoch)
        train(train_loader,
              model,
              optimizer,
              epoch,
              args.epochs,
              criterion,
              num_grad_acc_steps=args.num_grad_acc_steps,
              logger=logger,
              log_aggr=args.log_aggr)
        dice_score = validate(val_loader,
                              model,
                              epoch + 1,
                              logger,
                              is_eval=args.batch_size > 1,
                              is_full_size=is_full_size)

        # store best loss and save a model checkpoint
        is_best = dice_score > best_score
        prev_best_score = best_score
        best_score = max(dice_score, best_score)
        ckpt_dict = {
            'epoch': epoch + 1,
            'arch': experiment,
            'state_dict': model.state_dict(),
            'best_score': best_score,
            'cur_score': dice_score,
            'optimizer': optimizer.state_dict(),
        }
        ckpt_dict['filter_sizes'] = filters_sizes

        if is_best:
            print('Best snapshot! {} -> {}'.format(prev_best_score, best_score))
            logger.add_text('val/best_dice',
                            'best val dice score: {}'.format(dice_score),
                            global_step=epoch + 1)
        save_checkpoint(ckpt_dict,
                        is_best,
                        filepath=join(out_dir, 'checkpoint.pth.tar'))

    logger.close()
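Setting scheduler.base_lrs to all ones after constructing the LambdaLR is a trick worth spelling out: LambdaLR computes each group's rate as base_lr * lr_lambda(epoch), so with a base of 1.0 the callable (CyclicLr or VggCyclicLr here) returns the absolute learning rate rather than a multiplier. A minimal sketch of the idea with a plain function; the cycle shape below is illustrative, not the one those classes implement:

import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

def absolute_lr(epoch, init_lr=0.01, cycle=30, decay=0.5):
    # Restart every `cycle` epochs; halve the rate every 10 epochs within a cycle
    return init_lr * decay ** ((epoch % cycle) // 10)

scheduler = LambdaLR(optimizer, lr_lambda=absolute_lr)
scheduler.base_lrs = [1.0 for _ in optimizer.param_groups]
# From the next scheduler.step() on, the scheduled rate equals absolute_lr(epoch) exactly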