Example #1
def main():

    #     writer = SummaryWriter()
    test_acc_list = []
    logfilename = os.path.join('.', 'log3.txt')
    init_logfile(logfilename,
                 "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")

    #     net = ResNet(BasicBlock, [3, 3, 3]).to(device)

    net = VGG_SNIP('D').to(device)
    #     criterion = nn.CrossEntropyLoss().to(device)
    criterion = nn.NLLLoss().to(device)
    optimizer = SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[80, 120], last_epoch=-1)
    train_loader, test_loader = get_cifar10_dataloaders(128, 128)

    keep_masks = SNIP(net, 0.05, train_loader, device)  # TODO: shuffle?
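    # SNIP scores each connection's saliency on batches drawn from train_loader; 0.05 is
    # presumably the keep ratio, i.e. roughly 5% of the weights survive pruning (an
    # assumption based on common SNIP implementations), and apply_prune_mask below
    # applies the resulting binary masks to the network's weights.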
    apply_prune_mask(net, keep_masks)

    for epoch in range(160):
        before = time.time()
        train_loss, train_acc = train(train_loader,
                                      net,
                                      criterion,
                                      optimizer,
                                      epoch,
                                      device,
                                      100,
                                      display=True)
        test_loss, test_acc = test(test_loader,
                                   net,
                                   criterion,
                                   device,
                                   100,
                                   display=True)

        scheduler.step(epoch)
        after = time.time()

        log(
            logfilename, "{}\t{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, str(datetime.timedelta(seconds=(after - before))),
                scheduler.get_lr()[0], train_loss, train_acc, test_loss,
                test_acc))

        print("test_acc: ", test_acc)
    test_acc_list.append(test_acc)
    log(logfilename, "This is the test accuracy list for args.round.")
    log(logfilename, str(test_acc_list))
Example #2
def train(data_path, models_path, backend, snapshot, crop_x, crop_y,
          batch_size, alpha, epochs, start_lr, milestones, gpu):
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    net, starting_epoch = build_network(snapshot, backend)
    data_path = os.path.abspath(os.path.expanduser(data_path))
    models_path = os.path.abspath(os.path.expanduser(models_path))
    os.makedirs(models_path, exist_ok=True)
    '''
        To follow this training routine you need a DataLoader that yields the tuples of the following format:
        (Bx3xHxW FloatTensor x, BxHxW LongTensor y, BxN LongTensor y_cls) where
        x - batch of input images,
        y - batch of ground truth seg maps,
        y_cls - batch of 1D tensors of dimensionality N: N total number of classes, 
        y_cls[i, T] = 1 if class T is present in image i, 0 otherwise
    '''
    train_loader, class_weights, n_images = None, None, None
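    # Hypothetical sketch (an assumption, not part of the original script): fill the
    # placeholders above with a tiny in-memory dataset that yields the (x, y, y_cls)
    # tuples described in the docstring, so the training loop below can iterate.
    from torch.utils.data import Dataset, DataLoader

    class ToySegDataset(Dataset):  # illustrative stand-in for a real segmentation dataset
        def __init__(self, n_items=8, n_classes=21, h=64, w=64):
            self.n_items, self.n_classes, self.h, self.w = n_items, n_classes, h, w

        def __len__(self):
            return self.n_items

        def __getitem__(self, idx):
            x = torch.randn(3, self.h, self.w)                      # 3xHxW input image
            y = torch.randint(0, self.n_classes, (self.h, self.w))  # HxW ground-truth seg map
            y_cls = torch.zeros(self.n_classes)                     # class presence (float so it works with BCEWithLogitsLoss)
            y_cls[torch.unique(y)] = 1.0
            return x, y, y_cls

    n_images = 8
    max_steps, steps = n_images, 0
    train_loader = DataLoader(ToySegDataset(n_images), batch_size=batch_size, shuffle=True)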

    optimizer = optim.Adam(net.parameters(), lr=start_lr)
    scheduler = MultiStepLR(optimizer,
                            milestones=[int(x) for x in milestones.split(',')])

    for epoch in range(starting_epoch, starting_epoch + epochs):
        seg_criterion = nn.NLLLoss2d(weight=class_weights)
        cls_criterion = nn.BCEWithLogitsLoss(weight=class_weights)
        epoch_losses = []
        train_iterator = tqdm(train_loader, total=max_steps // batch_size + 1)
        net.train()
        for x, y, y_cls in train_iterator:
            steps += batch_size
            optimizer.zero_grad()
            x, y, y_cls = Variable(x).cuda(), Variable(y).cuda(), Variable(
                y_cls).cuda()
            out, out_cls = net(x)
            seg_loss, cls_loss = seg_criterion(out, y), cls_criterion(
                out_cls, y_cls)
            loss = seg_loss + alpha * cls_loss
            epoch_losses.append(loss.data[0])
            status = '[{0}] loss = {1:0.5f} avg = {2:0.5f}, LR = {3:0.7f}'.format(
                epoch + 1, loss.data[0], np.mean(epoch_losses),
                scheduler.get_lr()[0])
            train_iterator.set_description(status)
            loss.backward()
            optimizer.step()
        scheduler.step()
        torch.save(
            net.state_dict(),
            os.path.join(models_path, '_'.join(["PSPNet",
                                                str(epoch + 1)])))
        train_loss = np.mean(epoch_losses)
Example #3
def train_folds(writers):
    kvs = GlobalKVS()
    for fold_id in kvs['cv_split_train']:
        kvs.update('cur_fold', fold_id)
        kvs.update('prev_model', None)
        print(colored('====> ', 'blue') + f'Training fold {fold_id}....')

        train_index, val_index = kvs['cv_split_train'][fold_id]
        train_loader, val_loader = session.init_loaders(
            kvs['metadata'].iloc[train_index], kvs['metadata'].iloc[val_index])

        net = init_model()
        optimizer = init_optimizer([{
            'params':
            net.module.classifier_kl.parameters()
        }, {
            'params':
            net.module.classifier_prog.parameters()
        }])

        scheduler = MultiStepLR(optimizer,
                                milestones=kvs['args'].lr_drop,
                                gamma=0.1)

        for epoch in range(kvs['args'].n_epochs):
            kvs.update('cur_epoch', epoch)
            if epoch == kvs['args'].unfreeze_epoch:
                print(colored('====> ', 'red') + 'Unfreezing the layers!')
                new_lr_drop_milestones = list(
                    map(lambda x: x - kvs['args'].unfreeze_epoch,
                        kvs['args'].lr_drop))
                optimizer.add_param_group(
                    {'params': net.module.features.parameters()})
                scheduler = MultiStepLR(optimizer,
                                        milestones=new_lr_drop_milestones,
                                        gamma=0.1)

            print(colored('====> ', 'red') + 'LR:', scheduler.get_lr())
            train_loss = prog_epoch_pass(net, optimizer, train_loader)
            val_out = prog_epoch_pass(net, None, val_loader)
            val_loss, val_ids, gt_progression, preds_progression, gt_kl, preds_kl = val_out
            log_metrics_prog(writers[fold_id], train_loss, val_loss,
                             gt_progression, preds_progression, gt_kl,
                             preds_kl)

            session.save_checkpoint(net, 'ap_prog', 'gt')
            scheduler.step()
Example #4
def train(args):
    train_loader, val_loader = prepare_data(args)
    net = build_net(args)
    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    if args.lr_decay == 'multistep':
        lr_scheduler = MultiStepLR(
            optimizer,
            milestones=[int(args.epoch * 0.5),
                        int(args.epoch * 0.75)],
            gamma=0.1)
    elif args.lr_decay == 'cos':
        lr_scheduler = CosineAnnealingLR(optimizer, T_max=args.epoch)
    elif args.lr_decay == 'step':
        lr_scheduler = StepLR(optimizer, step_size=50, gamma=0.1)
    best_acc = 0
    checkpoint = {}
    for epochid in range(args.epoch):
        print("==> Training Epoch %d, Learning Rate %.4f" %
              (epochid, lr_scheduler.get_lr()[0]))
        train_epoch(net, train_loader, optimizer, args)
        print('==> Validating ')
        acc = validate(net, val_loader, args)
        lr_scheduler.step()
        if acc > best_acc:
            best_acc = acc
            if args.cpu or len(args.gpus) == 1:
                # Use cpu or one single gpu to train the model
                checkpoint = net.state_dict()
            elif len(args.gpus) > 1:
                checkpoint = net.module.state_dict()

    fname = args.arch + '_' + str(best_acc) + '.pth.tar'
    os.makedirs(args.outdir, exist_ok=True)
    fname = os.path.join(args.outdir, fname)
    torch.save(checkpoint, fname)
    print('Best Accuracy: ', best_acc)
elif args.method == 'padam':
    import Padam
    optimizer = Padam.Padam(model.parameters(),
                            lr=args.lr,
                            partial=args.partial,
                            weight_decay=args.wd,
                            betas=betas)
else:
    print('Optimizer undefined!')

scheduler = MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

for epoch in range(start_epoch + 1, args.Nepoch + 1):

    scheduler.step()
    print('\nEpoch: %d' % epoch, ' Learning rate:', scheduler.get_lr())
    model.train()  # Training

    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        def closure():
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            return loss

        optimizer.zero_grad()
Example #6
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_transform = T.Compose([
        T.RandomRotation(args.rotation),
        T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.GaussianBlur(),
        T.ToTensor(), normalize
    ])
    val_transform = T.Compose(
        [T.Resize(args.image_size),
         T.ToTensor(), normalize])
    image_size = (args.image_size, args.image_size)
    heatmap_size = (args.heatmap_size, args.heatmap_size)
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_source_dataset = source_dataset(root=args.source_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_source_loader = DataLoader(val_source_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(root=args.target_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)

    print("Source train:", len(train_source_loader))
    print("Target train:", len(train_target_loader))
    print("Source test:", len(val_source_loader))
    print("Target test:", len(val_target_loader))

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    backbone = models.__dict__[args.arch](pretrained=True)
    upsampling = Upsampling(backbone.out_features)
    num_keypoints = train_source_dataset.num_keypoints
    model = RegDAPoseResNet(backbone,
                            upsampling,
                            256,
                            num_keypoints,
                            num_head_layers=args.num_head_layers,
                            finetune=True).to(device)
    # define loss function
    criterion = JointsKLLoss()
    pseudo_label_generator = PseudoLabelGenerator(num_keypoints,
                                                  args.heatmap_size,
                                                  args.heatmap_size)
    regression_disparity = RegressionDisparity(pseudo_label_generator,
                                               JointsKLLoss(epsilon=1e-7))

    # define optimizer and lr scheduler
    optimizer_f = SGD([
        {
            'params': backbone.parameters(),
            'lr': 0.1
        },
        {
            'params': upsampling.parameters(),
            'lr': 0.1
        },
    ],
                      lr=0.1,
                      momentum=args.momentum,
                      weight_decay=args.wd,
                      nesterov=True)
    optimizer_h = SGD(model.head.parameters(),
                      lr=1.,
                      momentum=args.momentum,
                      weight_decay=args.wd,
                      nesterov=True)
    optimizer_h_adv = SGD(model.head_adv.parameters(),
                          lr=1.,
                          momentum=args.momentum,
                          weight_decay=args.wd,
                          nesterov=True)
    lr_decay_function = lambda x: args.lr * (1. + args.lr_gamma * float(x))**(
        -args.lr_decay)
    lr_scheduler_f = LambdaLR(optimizer_f, lr_decay_function)
    lr_scheduler_h = LambdaLR(optimizer_h, lr_decay_function)
    lr_scheduler_h_adv = LambdaLR(optimizer_h_adv, lr_decay_function)
    start_epoch = 0

    if args.resume is None:
        if args.pretrain is None:
            # first pretrain the backbone and upsampling
            print("Pretraining the model on source domain.")
            args.pretrain = logger.get_checkpoint_path('pretrain')
            pretrained_model = PoseResNet(backbone, upsampling, 256,
                                          num_keypoints, True).to(device)
            optimizer = SGD(pretrained_model.get_parameters(lr=args.lr),
                            momentum=args.momentum,
                            weight_decay=args.wd,
                            nesterov=True)
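            # MultiStepLR's positional arguments after the optimizer are the milestone
            # epochs and the multiplicative decay factor gamma, so args.lr_step and
            # args.lr_factor act as milestones and gamma here.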
            lr_scheduler = MultiStepLR(optimizer, args.lr_step, args.lr_factor)
            best_acc = 0
            for epoch in range(args.pretrain_epochs):
                lr_scheduler.step()
                print(lr_scheduler.get_lr())

                pretrain(train_source_iter, pretrained_model, criterion,
                         optimizer, epoch, args)
                source_val_acc = validate(val_source_loader, pretrained_model,
                                          criterion, None, args)

                # remember best acc and save checkpoint
                if source_val_acc['all'] > best_acc:
                    best_acc = source_val_acc['all']
                    torch.save({'model': pretrained_model.state_dict()},
                               args.pretrain)
                print("Source: {} best: {}".format(source_val_acc['all'],
                                                   best_acc))

        # load from the pretrained checkpoint
        pretrained_dict = torch.load(args.pretrain,
                                     map_location='cpu')['model']
        model_dict = model.state_dict()
        # remove keys from the pretrained dict that don't appear in the model dict
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model.load_state_dict(pretrained_dict, strict=False)
    else:
        # optionally resume from a checkpoint
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer_f.load_state_dict(checkpoint['optimizer_f'])
        optimizer_h.load_state_dict(checkpoint['optimizer_h'])
        optimizer_h_adv.load_state_dict(checkpoint['optimizer_h_adv'])
        lr_scheduler_f.load_state_dict(checkpoint['lr_scheduler_f'])
        lr_scheduler_h.load_state_dict(checkpoint['lr_scheduler_h'])
        lr_scheduler_h_adv.load_state_dict(checkpoint['lr_scheduler_h_adv'])
        start_epoch = checkpoint['epoch'] + 1

    # define visualization function
    tensor_to_image = Compose([
        Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToPILImage()
    ])

    def visualize(image, keypoint2d, name, heatmaps=None):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            keypoint2d (tensor): keypoints in shape K x 2
            name: name of the saving image
        """
        train_source_dataset.visualize(
            tensor_to_image(image), keypoint2d,
            logger.get_image_path("{}.jpg".format(name)))

    if args.phase == 'test':
        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize, args)
        print("Source: {:4.3f} Target: {:4.3f}".format(source_val_acc['all'],
                                                       target_val_acc['all']))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))
        return

    # start training
    best_acc = 0
    print("Start regression domain adaptation.")
    for epoch in range(start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler_f.get_lr(), lr_scheduler_h.get_lr(),
              lr_scheduler_h_adv.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, model, criterion,
              regression_disparity, optimizer_f, optimizer_h, optimizer_h_adv,
              lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv, epoch,
              visualize if args.debug else None, args)

        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize if args.debug else None, args)

        # remember best acc and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer_f': optimizer_f.state_dict(),
                'optimizer_h': optimizer_h.state_dict(),
                'optimizer_h_adv': optimizer_h_adv.state_dict(),
                'lr_scheduler_f': lr_scheduler_f.state_dict(),
                'lr_scheduler_h': lr_scheduler_h.state_dict(),
                'lr_scheduler_h_adv': lr_scheduler_h_adv.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if target_val_acc['all'] > best_acc:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
            best_acc = target_val_acc['all']
        print("Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}".format(
            source_val_acc['all'], target_val_acc['all'], best_acc))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))

    logger.close()
def main():
    global args
    args = parser.parse_args()

    if torch.cuda.is_available():
        cudnn.benchmark = True
    # model definition
    if args.arch == 'senet152':
        model = model_zoo[args.arch]()
        state_dict = torch.load(
            'D:\\PycharmWorkspace\\Torch-Texture-Classification\\se_resnet152-d17c99b7.pth'
        )
        new_state_dict = convert_state_dict_for_seresnet(state_dict)
        model.load_state_dict(new_state_dict)
    elif args.arch == 'se_resnet50':
        model = model_zoo[args.arch]()
        state_dict = torch.load(
            'D:\\PycharmWorkspace\\Torch-Texture-Classification\\seresnet50-60a8950a85b2b.pkl'
        )
        model.load_state_dict(state_dict)
    elif args.arch == 'nts':
        baseblock_arch = 'resnet50'
        model = attention_net(arch=baseblock_arch,
                              topN=4,
                              num_classes=args.num_classes)
        # state_dict = torch.load('D:\\PycharmWorkspace\\Torch-Texture-Classification\\nts50.pth')['net_state_dict']
        # model.load_state_dict(state_dict)
    else:
        model = model_zoo[args.arch](pretrained=True)
    if 'resnet' in args.arch:
        model.fc = nn.Linear(model.fc.in_features, args.num_classes)
    elif 'vgg' in args.arch:
        model.classifier[6] = nn.Linear(model.classifier[6].in_features,
                                        args.num_classes)
    elif 'senet' in args.arch:
        model.fc = nn.Linear(model.fc.in_features, args.num_classes)
    elif 'se_' in args.arch:
        model.fc = nn.Linear(model.fc.in_features, args.num_classes)
    elif 'inception' in args.arch:
        model.fc = nn.Linear(model.fc.in_features, args.num_classes)
    elif 'nts' in args.arch:
        pass  # done in NTS/core/model.py __init__ of attention_net
    print(model)
    # exit()
    # resume checkpoint
    checkpoint = None
    # if args.resume:
    #     device = torch.cuda.current_device()
    #     checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(device))
    #     state = convert_state_dict(checkpoint['model_state'])
    #     model.load_state_dict(state)

    model.cuda()
    model = freeze_params(args.arch, model)
    print('trainable parameters:')
    for param in model.named_parameters():
        if param[1].requires_grad:
            print(param[0])  # [0] name, [1] params

    # criterion = CrossEntropyLabelSmooth(args.num_classes, epsilon=0.1)
    criterion = LabelSmoothingLoss(args.num_classes, smoothing=0.1)
    # criterion = nn.CrossEntropyLoss().cuda()

    # no bias decay
    param_optimizer = list(
        filter(lambda p: p.requires_grad, model.parameters()))
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay) and p.requires_grad
        ],
        'weight_decay':
        0.001
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay) and p.requires_grad
        ],
        'weight_decay':
        0.0
    }]
    optimizer = torch.optim.SGD(optimizer_grouped_parameters,
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=5e-4)

    # original optimizer
    # optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
    #                             lr=args.lr, momentum=0.9, weight_decay=5e-4)

    # params should be a dict or an iterable of tensors
    # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
    #                              lr=args.lr, weight_decay=0.01)
    if args.resume and checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state'])
    # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # calculated by transforms_utils.py
    normalize = transforms.Normalize(mean=[0.5335619, 0.47571668, 0.4280075],
                                     std=[0.26906276, 0.2592897, 0.26745376])
    transform_train = transforms.Compose([
        transforms.Resize(args.resize),  # 384, 256
        transforms.RandomCrop(args.crop),  # 320, 224
        transforms.RandomHorizontalFlip(),
        # transforms.RandomVerticalFlip(),
        transforms.ColorJitter(0.4, 0.4, 0.4),
        # transforms.RandomRotation(45),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.Resize(args.resize),  # 384
        transforms.RandomCrop(args.crop),  # 320
        transforms.ToTensor(),
        normalize,
    ])

    train_data_root = 'D:\\PycharmWorkspace\\Torch-Texture-Classification\\dataset\\train'
    test_data_root = 'D:\\PycharmWorkspace\\Torch-Texture-Classification\\dataset\\test'
    train_dataset = TextureDataset(train_data_root,
                                   train=True,
                                   transform=transform_train)
    test_dataset = TextureDataset(test_data_root,
                                  train=False,
                                  transform=transform_test)
    sampler = get_sampler(train_dataset, args.num_classes)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.BATCH,
                              shuffle=False,
                              sampler=sampler,
                              num_workers=0,
                              pin_memory=True)  # num_workers != 0 will break
    test_loader = DataLoader(test_dataset,
                             batch_size=args.BATCH,
                             shuffle=False,
                             num_workers=0,
                             pin_memory=False)
    print('train data length:', len(train_loader), '. test data length:',
          len(test_loader), '\n')

    lr_scheduler = MultiStepLR(optimizer,
                               milestones=args.decay_epoch,
                               gamma=args.gamma,
                               last_epoch=args.start_epoch)

    time_ = time.localtime(time.time())
    log_save_path = 'Logs/' + str(args.arch) + '_' + str(args.lr) + '_' + str(args.BATCH) + \
                    '_' + str(time_.tm_mon) + str(time_.tm_mday) + str(time_.tm_hour) + str(time_.tm_min) + \
                    '.log.txt'
    log_saver = open(log_save_path, mode='w')
    v_args = vars(args)
    for k, v in v_args.items():
        log_saver.writelines(str(k) + ' ' + str(v) + '\n')
    log_saver.close()

    global writer
    current_time = datetime.now().strftime('%b%d_%H-%M')
    logdir = os.path.join('TensorBoardXLog', current_time)
    writer = SummaryWriter(log_dir=logdir)

    # dummy_input = torch.randn(args.BATCH, 3, args.crop, args.crop).cuda()
    # writer.add_graph(model, dummy_input)
    best_score = 0
    for epoch in range(args.start_epoch + 1,
                       args.EPOCHS):  # args.start_epoch = -1 for MultistepLr
        log_saver = open(log_save_path, mode='a')
        lr_scheduler.step()
        train(train_loader,
              model,
              criterion,
              lr_scheduler,
              epoch,
              warm_up=False)
        prec1, loss = validate(test_loader, model, criterion)
        writer.add_scalar('scalar/test_prec', prec1, epoch)
        writer.add_scalar('scalar/test_loss', loss, epoch)
        print('test average is: ', prec1)
        log_saver.writelines('learning rate:' + str(lr_scheduler.get_lr()[0]) +
                             ', epoch:' + str(epoch) + ', test average is: ' +
                             str(prec1) + ', loss average is: ' + str(loss) +
                             '\n')

        save_name = str(args.lr) + '_' + str(args.BATCH)
        save_dir = os.path.join(
            args.save_dir_path,
            str(args.arch) + '_' + str(time_.tm_mon) + str(time_.tm_mday) +
            str(time_.tm_hour) + str(time_.tm_min))
        if prec1 > best_score:
            best_score = prec1
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': model.cpu().state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                }, save_dir, save_name + '_ckpt_e{}'.format(epoch))
        log_saver.close()
    writer.close()
Example #8
        milestones = [int(v.strip()) for v in args.milestones.split(",")]
        scheduler = MultiStepLR(optimizer, milestones=milestones,
                                gamma=0.1, last_epoch=last_epoch)
    elif args.scheduler == 'cosine':
        logging.info("Uses CosineAnnealingLR scheduler.")
        scheduler = CosineAnnealingLR(optimizer, args.t_max, last_epoch=last_epoch)
    else:
        logging.fatal(f"Unsupported Scheduler: {args.scheduler}.")
        parser.print_help(sys.stderr)
        sys.exit(1)

    last_epoch += 1
    logging.info(f"Start training from epoch {last_epoch + 1}.")
    for epoch in range(last_epoch, args.num_epochs):
        scheduler.step()
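        # get_lr() returns one learning rate per optimizer param group, hence the
        # branching below on how many values to log.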
        lr = scheduler.get_lr()
        if len(lr) == 1:
            writer.add_scalar("learn_rate/pred_heads_lr", lr[0], epoch)
        elif len(lr) == 2:
            writer.add_scalar("learn_rate/ssd_lr", lr[0], epoch)
            writer.add_scalar("learn_rate/pred_heads_lr", lr[1], epoch)
        else:
            writer.add_scalar("learn_rate/mob_lr", lr[0], epoch)
            writer.add_scalar("learn_rate/ssd_lr", lr[1], epoch)
            writer.add_scalar("learn_rate/pred_heads_lr", lr[2], epoch)

        train(train_loader, net, criterion, optimizer,
              device=DEVICE, debug_steps=args.debug_steps, epoch=epoch)

        if epoch == 0:
            continue
def train_model_residual_lowlight_twostage_gan_best():

    # set the hyperparameters
    batchsize = 128
    init_lr = 0.001
    K_adjacent_band = 36
    display_step = 20
    display_band = 20
    is_resume = False
    lambda_recon = 10

    start_epoch = 1

    device = DEVICE
    # prepare the training data
    train_set = HsiCubicTrainDataset('./data/train_lowlight/')
    print('total training example:', len(train_set))

    train_loader = DataLoader(dataset=train_set,
                              batch_size=batchsize,
                              shuffle=True)

    # load the test label data
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # load the test data
    test_batch_size = 1
    test_data_dir = './data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=test_batch_size,
                                 shuffle=False)

    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape

    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # create the model
    net = HSIDDenseNetTwoStage(K_adjacent_band)
    init_params(net)
    #net = nn.DataParallel(net).to(device)
    net = net.to(device)

    # create the discriminator
    disc = DiscriminatorABC(2, 4)
    init_params(disc)
    disc = disc.to(device)
    disc_opt = torch.optim.Adam(disc.parameters(), lr=init_lr)

    num_epoch = 100
    print('epoch count == ', num_epoch)

    # create the optimizer
    #hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE, betas=(0.9, 0.999))
    hsid_optimizer = optim.Adam(net.parameters(), lr=init_lr)

    #Scheduler
    scheduler = MultiStepLR(hsid_optimizer, milestones=[40, 60, 80], gamma=0.1)
    warmup_epochs = 3
    #scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(hsid_optimizer, num_epoch-warmup_epochs+40, eta_min=1e-7)
    #scheduler = GradualWarmupScheduler(hsid_optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
    #scheduler.step()

    # resume training if requested
    if is_resume:
        model_dir = './checkpoints'
        path_chk_rest = dir_utils.get_last_path(model_dir, 'model_latest.pth')
        model_utils.load_checkpoint(net, path_chk_rest)
        start_epoch = model_utils.load_start_epoch(path_chk_rest) + 1
        model_utils.load_optim(hsid_optimizer, path_chk_rest)
        model_utils.load_disc_checkpoint(disc, path_chk_rest)
        model_utils.load_disc_optim(disc_opt, path_chk_rest)

        for i in range(1, start_epoch):
            scheduler.step()
        new_lr = scheduler.get_lr()[0]
        print(
            '------------------------------------------------------------------------------'
        )
        print("==> Resuming Training with learning rate:", new_lr)
        print(
            '------------------------------------------------------------------------------'
        )

    # define the loss function
    #criterion = nn.MSELoss()

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0

    first_batch = next(iter(train_loader))

    best_psnr = 0
    best_epoch = 0
    best_iter = 0

    for epoch in range(start_epoch, num_epoch + 1):
        epoch_start_time = time.time()
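        # NOTE: stepping the scheduler at the top of the epoch (before optimizer.step())
        # is an older PyTorch idiom; PyTorch >= 1.1 expects scheduler.step() after the
        # optimizer update, and get_last_lr() is the preferred way to read the current LR.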
        scheduler.step()
        #print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        print('epoch = ', epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))
        print(scheduler.get_lr())
        gen_epoch_loss = 0

        net.train()
        #for batch_idx, (noisy, label) in enumerate([first_batch] * 300):
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
            #print('batch_idx=', batch_idx)
            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            ### Update discriminator ###
            disc_opt.zero_grad()  # Zero out the gradient before backpropagation
            with torch.no_grad():
                fake, fake_stage2 = net(noisy, cubic)
            #print('noisy shape =', noisy.shape, fake_stage2.shape)
            #fake.detach()
            disc_fake_hat = disc(fake_stage2.detach() + noisy,
                                 noisy)  # Detach generator
            disc_fake_loss = adv_criterion(disc_fake_hat,
                                           torch.zeros_like(disc_fake_hat))
            disc_real_hat = disc(label, noisy)
            disc_real_loss = adv_criterion(disc_real_hat,
                                           torch.ones_like(disc_real_hat))
            disc_loss = (disc_fake_loss + disc_real_loss) / 2
            disc_loss.backward(retain_graph=True)  # Update gradients
            disc_opt.step()  # Update optimizer

            ### Update generator ###
            hsid_optimizer.zero_grad()
            #denoised_img = net(noisy, cubic)
            #loss = loss_fuction(denoised_img, label)

            residual, residual_stage2 = net(noisy, cubic)
            disc_fake_hat = disc(residual_stage2 + noisy, noisy)
            gen_adv_loss = adv_criterion(disc_fake_hat,
                                         torch.ones_like(disc_fake_hat))

            alpha = 0.2
            beta = 0.2
            rec_loss = beta * (alpha*loss_fuction(residual, label-noisy) + (1-alpha) * recon_criterion(residual, label-noisy)) \
             + (1-beta) * (alpha*loss_fuction(residual_stage2, label-noisy) + (1-alpha) * recon_criterion(residual_stage2, label-noisy))

            loss = gen_adv_loss + lambda_recon * rec_loss

            loss.backward()  # calcu gradient
            hsid_optimizer.step()  # update parameter

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                    )
                    print(
                        f"rec_loss =  {rec_loss.item()}, gen_adv_loss = {gen_adv_loss.item()}"
                    )

                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            # step += 1: each loop iteration (one batch) counts as one step
            cur_step += 1

        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        #scheduler.step()
        #print("Decaying learning rate to %g" % scheduler.get_last_lr()[0])

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
                'disc': disc.state_dict(),
                'disc_opt': disc_opt.state_dict()
            }, f"checkpoints/two_stage_hsid_dense_gan_{epoch}.pth")

        # evaluation on the test set
        net.eval()
        for batch_idx, (noisy_test, cubic_test,
                        label_test) in enumerate(test_dataloader):
            noisy_test = noisy_test.type(torch.FloatTensor)
            label_test = label_test.type(torch.FloatTensor)
            cubic_test = cubic_test.type(torch.FloatTensor)

            noisy_test = noisy_test.to(DEVICE)
            label_test = label_test.to(DEVICE)
            cubic_test = cubic_test.to(DEVICE)

            with torch.no_grad():

                residual, residual_stage2 = net(noisy_test, cubic_test)
                denoised_band = noisy_test + residual_stage2

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, batch_idx] = denoised_band_numpy

                if batch_idx == 49:
                    residual_squeezed = torch.squeeze(residual, axis=0)
                    residual_stage2_squeezed = torch.squeeze(residual_stage2,
                                                             axis=0)
                    denoised_band_squeezed = torch.squeeze(denoised_band,
                                                           axis=0)
                    label_test_squeezed = torch.squeeze(label_test, axis=0)
                    noisy_test_squeezed = torch.squeeze(noisy_test, axis=0)
                    tb_writer.add_image(f"images/{epoch}_restored",
                                        denoised_band_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_residual",
                                        residual_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_residual_stage2",
                                        residual_stage2_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_label",
                                        label_test_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_noisy",
                                        noisy_test_squeezed,
                                        1,
                                        dataformats='CHW')

        psnr = PSNR(denoised_hsi, test_label_hsi)
        ssim = SSIM(denoised_hsi, test_label_hsi)
        sam = SAM(denoised_hsi, test_label_hsi)

        # compute PSNR and SSIM
        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(psnr, ssim, sam))
        tb_writer.add_scalars("validation metrics", {
            'average PSNR': psnr,
            'average SSIM': ssim,
            'avarage SAM': sam
        }, epoch)  #通过这个我就可以看到,那个epoch的性能是最好的

        # save the best model
        if psnr > best_psnr:
            best_psnr = psnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                    'disc': disc.state_dict(),
                    'disc_opt': disc_opt.state_dict()
                }, f"checkpoints/two_stage_hsid_dense_gan_best.pth")

        print(
            "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
            % (epoch, cur_step, psnr, best_epoch, best_iter, best_psnr))

        print(
            "------------------------------------------------------------------"
        )
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
              format(epoch,
                     time.time() - epoch_start_time, gen_epoch_loss,
                     scheduler.get_lr()[0]))
        print(
            "------------------------------------------------------------------"
        )

        torch.save(
            {
                'epoch': epoch,
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
                'disc': disc.state_dict(),
                'disc_opt': disc_opt.state_dict()
            }, os.path.join('./checkpoints', "model_latest.pth"))

    tb_writer.close()
def train_model_residual_lowlight_rdn():

    device = DEVICE
    # prepare the data
    train = np.load('./data/denoise/train_washington8.npy')
    train = train.transpose((2, 1, 0))

    test = np.load('./data/denoise/train_washington8.npy')
    #test=test.transpose((2,1,0))
    test = test.transpose((2, 1, 0))  # put the channel dimension first

    save_model_path = './checkpoints/hsirnd_denoise_l1loss'
    if not os.path.exists(save_model_path):
        os.mkdir(save_model_path)

    # create the model
    net = HSIRDNECA_Denoise(K)
    init_params(net)
    net = nn.DataParallel(net).to(device)
    #net = net.to(device)

    # create the optimizer
    #hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE, betas=(0.9, 0.999))
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[200, 400], gamma=0.5)

    # define the loss function
    #criterion = nn.MSELoss()

    gen_epoch_loss_list = []

    cur_step = 0

    best_psnr = 0
    best_epoch = 0
    best_iter = 0
    start_epoch = 1
    num_epoch = 600

    mpsnr_list = []
    for epoch in range(start_epoch, num_epoch + 1):
        epoch_start_time = time.time()
        scheduler.step()
        print('epoch = ', epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))
        print(scheduler.get_lr())

        gen_epoch_loss = 0

        net.train()

        channels = 191  # 191 channels
        data_patches, data_cubic_patches = datagenerator(train, channels)

        data_patches = torch.from_numpy(data_patches.transpose((
            0,
            3,
            1,
            2,
        )))
        data_cubic_patches = torch.from_numpy(
            data_cubic_patches.transpose((0, 4, 1, 2, 3)))

        DDataset = DenoisingDataset(data_patches, data_cubic_patches, SIGMA)

        print('yes')
        DLoader = DataLoader(dataset=DDataset,
                             batch_size=BATCH_SIZE,
                             shuffle=True)  # the loader had issues here

        epoch_loss = 0
        start_time = time.time()

        #for batch_idx, (noisy, label) in enumerate([first_batch] * 300):
        for step, x_y in enumerate(DLoader):
            #print('batch_idx=', batch_idx)
            batch_x_noise, batch_y_noise, batch_x = x_y[0], x_y[1], x_y[2]

            batch_x_noise = batch_x_noise.to(device)
            batch_y_noise = batch_y_noise.to(device)
            batch_x = batch_x.to(device)

            hsid_optimizer.zero_grad()
            #denoised_img = net(noisy, cubic)
            #loss = loss_fuction(denoised_img, label)

            residual = net(batch_x_noise, batch_y_noise)
            alpha = 0.8
            loss = recon_criterion(residual, batch_x - batch_x_noise)
            #loss = alpha*recon_criterion(residual, label-noisy) + (1-alpha)*loss_function_mse(residual, label-noisy)
            #loss = recon_criterion(residual, label-noisy)
            loss.backward()  # calcu gradient
            hsid_optimizer.step()  # update parameter

            if step % 10 == 0:
                print('%4d %4d / %4d loss = %2.8f' %
                      (epoch + 1, step, data_patches.size(0) // BATCH_SIZE,
                       loss.item() / BATCH_SIZE))

        #scheduler.step()
        #print("Decaying learning rate to %g" % scheduler.get_last_lr()[0])

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            },
            f"{save_model_path}/hsid_rdn_eca_l1_loss_600epoch_patchsize32_{epoch}.pth"
        )

        # evaluation on the test set
        net.eval()
        """
        channel_s = 191  # how many bands to use
        data_patches, data_cubic_patches = datagenerator(test, channel_s)

        data_patches = torch.from_numpy(data_patches.transpose((0, 3, 1, 2,)))
        data_cubic_patches = torch.from_numpy(data_cubic_patches.transpose((0, 4, 1, 2, 3)))

        DDataset = DenoisingDataset(data_patches, data_cubic_patches, SIGMA)
        DLoader = DataLoader(dataset=DDataset, batch_size=BATCH_SIZE, shuffle=True)
        epoch_loss = 0
        
        for step, x_y in enumerate(DLoader):
            batch_x_noise, batch_y_noise, batch_x = x_y[0], x_y[1], x_y[2]

            batch_x_noise = batch_x_noise.to(DEVICE)
            batch_y_noise = batch_y_noise.to(DEVICE)
            batch_x = batch_x.to(DEVICE)
            residual = net(batch_x_noise, batch_y_noise)

            loss = loss_fuction(residual, batch_x-batch_x_noise)

            epoch_loss += loss.item()

            if step % 10 == 0:
                print('%4d %4d / %4d test loss = %2.4f' % (
                    epoch + 1, step, data_patches.size(0) // BATCH_SIZE, loss.item() / BATCH_SIZE))
        """
        # load the test data
        test_data_dir = './data/denoise/test/'
        test_set = HsiTrainDataset(test_data_dir)

        test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

        # output directory for the denoised results
        test_result_output_path = './data/denoise/testresult/'
        if not os.path.exists(test_result_output_path):
            os.makedirs(test_result_output_path)

        # denoise the image channel by channel
        """
        Allocate a numpy array to store the denoised result.
        Iterate over all channels; for each channel, fetch its K adjacent bands
        with get_adjacent_spectral_bands, run the hsid model to predict the
        residual, and add the predicted residual to the noisy input band to get
        the output band.

        Save the denoised result in .mat layout.
        """
        psnr_list = []
        for batch_idx, (noisy, label) in enumerate(test_dataloader):
            noisy = noisy.type(torch.FloatTensor)
            label = label.type(torch.FloatTensor)

            batch_size, width, height, band_num = noisy.shape
            denoised_hsi = np.zeros((width, height, band_num))

            noisy = noisy.to(DEVICE)
            label = label.to(DEVICE)

            with torch.no_grad():
                for i in range(band_num):  # process each band in turn
                    current_noisy_band = noisy[:, :, :, i]
                    current_noisy_band = current_noisy_band[:, None]

                    adj_spectral_bands = get_adjacent_spectral_bands(
                        noisy, K, i)
                    #adj_spectral_bands = torch.transpose(adj_spectral_bands,3,1)  # move the channel dim to position 1
                    adj_spectral_bands = adj_spectral_bands.permute(0, 3, 1, 2)
                    adj_spectral_bands_unsqueezed = adj_spectral_bands.unsqueeze(
                        1)
                    #print(current_noisy_band.shape, adj_spectral_bands.shape)
                    residual = net(current_noisy_band,
                                   adj_spectral_bands_unsqueezed)
                    denoised_band = residual + current_noisy_band
                    denoised_band_numpy = denoised_band.cpu().numpy().astype(
                        np.float32)
                    denoised_band_numpy = np.squeeze(denoised_band_numpy)

                    denoised_hsi[:, :, i] += denoised_band_numpy

                    test_label_current_band = label[:, :, :, i]

                    label_band_numpy = test_label_current_band.cpu().numpy(
                    ).astype(np.float32)
                    label_band_numpy = np.squeeze(label_band_numpy)

                    #print(denoised_band_numpy.shape, label_band_numpy.shape, label.shape)
                    psnr = PSNR(denoised_band_numpy, label_band_numpy)
                    psnr_list.append(psnr)

            mpsnr = np.mean(psnr_list)
            mpsnr_list.append(mpsnr)

            denoised_hsi_trans = denoised_hsi.transpose(2, 0, 1)
            test_label_hsi_trans = np.squeeze(label.cpu().numpy().astype(
                np.float32)).transpose(2, 0, 1)
            mssim = SSIM(denoised_hsi_trans, test_label_hsi_trans)
            sam = SAM(denoised_hsi_trans, test_label_hsi_trans)

            # compute PSNR and SSIM
            print(
                "=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
                format(mpsnr, mssim, sam))

        # save the best model
        if mpsnr > best_psnr:
            best_psnr = mpsnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                },
                f"{save_model_path}/hsid_rdn_eca_l1_loss_600epoch_patchsize32_best.pth"
            )

        print(
            "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
            % (epoch, cur_step, mpsnr, best_epoch, best_iter, best_psnr))

        print(
            "------------------------------------------------------------------"
        )
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
              format(epoch,
                     time.time() - epoch_start_time, gen_epoch_loss,
                     INIT_LEARNING_RATE))
        print(
            "------------------------------------------------------------------"
        )
def train_model_residual_lowlight_rdn():

    device = DEVICE
    # prepare the data
    train_set = HsiCubicTrainDataset('./data/train_lowlight_patchsize32_K24/')
    #print('trainset32 training example:', len(train_set32))
    #train_set = HsiCubicTrainDataset('./data/train_lowlight/')

    #train_set_64 = HsiCubicTrainDataset('./data/train_lowlight_patchsize64/')

    #train_set_list = [train_set32, train_set_64]
    #train_set = ConcatDataset(train_set_list)  # samples must all be the same size or concatenation fails
    print('total training example:', len(train_set))

    train_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True)
    
    # load the test label data
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # load the test data
    batch_size = 1
    #test_data_dir = './data/test_lowlight/cuk12/'
    test_data_dir = './data/test_lowlight/cuk12/'

    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape
    
    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    save_model_path = './checkpoints/hsirnd_without_multiscale_and_eca'
    if not os.path.exists(save_model_path):
        os.mkdir(save_model_path)

    # create the model
    net = HSIRDNECAWithoutMultiScaleECA(K)
    init_params(net)
    net = nn.DataParallel(net).to(device)
    #net = net.to(device)

    # create the optimizer
    #hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE, betas=(0.9, 0.999))
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[200,400], gamma=0.5)

    # define the loss function
    #criterion = nn.MSELoss()
    best_psnr = 0

    is_resume = RESUME
    # resume training if requested
    if is_resume:
        path_chk_rest    = dir_utils.get_last_path(save_model_path, 'model_latest.pth')
        model_utils.load_checkpoint(net,path_chk_rest)
        start_epoch = model_utils.load_start_epoch(path_chk_rest) + 1
        model_utils.load_optim(hsid_optimizer, path_chk_rest)
        best_psnr = model_utils.load_best_psnr(path_chk_rest)

        for i in range(1, start_epoch):
            scheduler.step()
        new_lr = scheduler.get_lr()[0]
        print('------------------------------------------------------------------------------')
        print("==> Resuming Training with learning rate:", new_lr)
        print('------------------------------------------------------------------------------')

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0

    first_batch = next(iter(train_loader))

    best_epoch = 0
    best_iter = 0
    if not is_resume:
        start_epoch = 1
    num_epoch = 600

    mpsnr_list = []
    for epoch in range(start_epoch, num_epoch+1):
        epoch_start_time = time.time()
        scheduler.step()
        print('epoch = ', epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))
        print(scheduler.get_lr())

        gen_epoch_loss = 0

        net.train()
        #for batch_idx, (noisy, label) in enumerate([first_batch] * 300):
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
            #print('batch_idx=', batch_idx)
            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            hsid_optimizer.zero_grad()
            #denoised_img = net(noisy, cubic)
            #loss = loss_fuction(denoised_img, label)

            residual = net(noisy, cubic)
            alpha = 0.8
            loss = recon_criterion(residual, label-noisy)
            #loss = alpha*recon_criterion(residual, label-noisy) + (1-alpha)*loss_function_mse(residual, label-noisy)
            #loss = recon_criterion(residual, label-noisy)
            loss.backward() # calcu gradient
            hsid_optimizer.step() # update parameter

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}")
                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            # step += 1: each loop iteration (one batch) counts as one step
            cur_step += 1


        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        #scheduler.step()
        #print("Decaying learning rate to %g" % scheduler.get_last_lr()[0])
 
        torch.save({
            'gen': net.state_dict(),
            'gen_opt': hsid_optimizer.state_dict(),
        }, f"{save_model_path}/hsid_rdn_eca_without_multiscale_eca_l1_loss_600epoch_patchsize32_{epoch}.pth")

        # evaluation on the test set
        net.eval()
        psnr_list = []

        for batch_idx, (noisy_test, cubic_test, label_test) in enumerate(test_dataloader):
            noisy_test = noisy_test.type(torch.FloatTensor)
            label_test = label_test.type(torch.FloatTensor)
            cubic_test = cubic_test.type(torch.FloatTensor)

            noisy_test = noisy_test.to(DEVICE)
            label_test = label_test.to(DEVICE)
            cubic_test = cubic_test.to(DEVICE)

            with torch.no_grad():

                residual = net(noisy_test, cubic_test)
                denoised_band = noisy_test + residual
                
                denoised_band_numpy = denoised_band.cpu().numpy().astype(np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:,:,batch_idx] = denoised_band_numpy

                if batch_idx == 49:
                    residual_squeezed = torch.squeeze(residual, axis=0)
                    denoised_band_squeezed = torch.squeeze(denoised_band, axis=0) 
                    label_test_squeezed = torch.squeeze(label_test,axis=0)
                    noisy_test_squeezed = torch.squeeze(noisy_test,axis=0)
                    tb_writer.add_image(f"images/{epoch}_restored", denoised_band_squeezed, 1, dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_residual", residual_squeezed, 1, dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_label", label_test_squeezed, 1, dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_noisy", noisy_test_squeezed, 1, dataformats='CHW')

            test_label_current_band = test_label_hsi[:,:,batch_idx]

            psnr = PSNR(denoised_band_numpy, test_label_current_band)
            psnr_list.append(psnr)
        
        mpsnr = np.mean(psnr_list)
        mpsnr_list.append(mpsnr)

        denoised_hsi_trans = denoised_hsi.transpose(2,0,1)
        test_label_hsi_trans = test_label_hsi.transpose(2, 0, 1)
        mssim = SSIM(denoised_hsi_trans, test_label_hsi_trans)
        sam = SAM(denoised_hsi_trans, test_label_hsi_trans)


        # compute PSNR, SSIM and SAM
        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".format(mpsnr, mssim, sam))
        tb_writer.add_scalars("validation metrics", {'average PSNR': mpsnr,
                        'average SSIM': mssim,
                        'average SAM': sam}, epoch)  # makes it easy to see which epoch performs best

        # save the best model
        if mpsnr > best_psnr:
            best_psnr = mpsnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save({
                'epoch' : epoch,
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            }, f"{save_model_path}/hsid_rdn_eca_without_multiscale__eca_l1_loss_600epoch_patchsize32_best.pth")

        print("[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]" % (epoch, cur_step, mpsnr, best_epoch, best_iter, best_psnr))

        print("------------------------------------------------------------------")
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time, gen_epoch_loss, INIT_LEARNING_RATE))
        print("------------------------------------------------------------------")

        # save the latest model
        torch.save({'epoch': epoch, 
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                    'best_psnr': best_psnr,
                    }, os.path.join(save_model_path,"model_latest.pth"))
    mpsnr_list_numpy = np.array(mpsnr_list)
    np.save(os.path.join(save_model_path, "mpsnr_per_epoch.npy"), mpsnr_list_numpy)
    tb_writer.close()
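The loop above follows a residual-learning recipe: the network predicts the noise residual for one band, the criterion compares it with (label - noisy), and the denoised band is rebuilt as noisy + residual before PSNR is measured. Below is a minimal, self-contained sketch of that pattern; the tiny ConvNet, random tensors and milestone values are placeholders, not the actual HSID-style model or data.

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

toy_net = nn.Sequential(nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(),
                        nn.Conv2d(16, 1, 3, padding=1)).to(device)
criterion = nn.L1Loss()
optimizer = Adam(toy_net.parameters(), lr=1e-3)
scheduler = MultiStepLR(optimizer, milestones=[40, 80], gamma=0.1)

for epoch in range(2):  # toy loop; real training runs for hundreds of epochs
    noisy = torch.rand(4, 1, 32, 32, device=device)
    label = torch.rand(4, 1, 32, 32, device=device)

    toy_net.train()
    optimizer.zero_grad()
    residual = toy_net(noisy)                  # network predicts the residual
    loss = criterion(residual, label - noisy)  # compare against the true residual
    loss.backward()
    optimizer.step()
    scheduler.step()                           # one scheduler step per epoch

    toy_net.eval()
    with torch.no_grad():
        denoised = noisy + toy_net(noisy)      # reconstruct the clean estimate
        mse = torch.mean((denoised - label) ** 2)
        psnr = 10 * torch.log10(1.0 / mse)     # assumes data scaled to [0, 1]
    print("epoch {}: loss={:.4f}, PSNR={:.2f} dB".format(epoch, loss.item(), psnr.item()))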
Example #12
0
def main():
    global args, MODELS_DIR
    print args

    if args.dbg:
        MODELS_DIR = join(MODELS_DIR, 'dbg')

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # if torch.cuda.is_available() and not args.cuda:
    #     print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    print 'CudNN:', torch.backends.cudnn.version()
    print 'Run on {} GPUs'.format(torch.cuda.device_count())
    cudnn.benchmark = True

    is_sobel = args.arch.endswith('Sobel')
    print 'is_sobel', is_sobel

    print("=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch](
        num_classes=args.num_clusters if args.unsupervised else 1000,
        dropout7_prob=args.dropout7_prob)
    model = torch.nn.DataParallel(model).cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    experiment = "{}_lr{}_{}{}".format(
        args.arch, args.lr, 'unsup' if args.unsupervised else 'labels',
        '_v2' if args.imagenet_version == 2 else '')
    if args.unsupervised:
        experiment += '{sobel_norm}_nc{nc}_l{clustering_layer}_rec{rec_epoch}{reset_fc}'.format(
            sobel_norm='_normed' if args.sobel_normalized else '',
            nc=args.num_clusters,
            clustering_layer=args.clustering_layer,
            rec_epoch=args.recluster_epoch,
            reset_fc='_reset-fc' if args.reset_fc else '')

    checkpoint = None
    if args.output_dir is None:
        args.output_dir = join(MODELS_DIR, experiment + '_' + args.exp_suffix)

    if args.output_dir is not None and os.path.exists(args.output_dir):
        ckpt_path = join(
            args.output_dir, 'checkpoint.pth.tar'
            if not args.from_best else 'model_best.pth.tar')
        if not os.path.isfile(ckpt_path):
            print "=> no checkpoint found at '{}'\nUsing model_best.pth.tar".format(
                ckpt_path)
            ckpt_path = join(args.output_dir, 'model_best.pth.tar')

        if os.path.isfile(ckpt_path):
            print("=> loading checkpoint '{}'".format(ckpt_path))
            checkpoint = torch.load(ckpt_path)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                ckpt_path, checkpoint['epoch']))
        else:
            print "=> no checkpoint found at '{}'\nUsing model_best_nmi.pth.tar".format(
                ckpt_path)
            ckpt_path = join(args.output_dir, 'model_best_nmi.pth.tar')

        if os.path.isfile(ckpt_path):
            print("=> loading checkpoint '{}'".format(ckpt_path))
            checkpoint = torch.load(ckpt_path)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                ckpt_path, checkpoint['epoch']))
        else:
            print "=> no checkpoint found at '{}'".format(ckpt_path)
            ans = None
            while ans != 'y' and ans != 'n':
                ans = raw_input('Clear the dir {}? [y/n] '.format(
                    args.output_dir)).lower()
            if ans.lower() == 'y':
                shutil.rmtree(args.output_dir)
            else:
                print 'Just write in the same dir.'
                # raise IOError("=> no checkpoint found at '{}'".format(ckpt_path))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    print 'Output dir:', args.output_dir

    start_epoch = 0
    best_score = 0
    best_nmi = 0
    if checkpoint is not None:
        start_epoch = checkpoint['epoch']
        if 'best_score' in checkpoint:
            best_score = checkpoint['best_score']
        else:
            print 'WARNING! No "best_score" found in checkpoint!'
            best_score = 0
        if 'nmi' in checkpoint:
            print 'Current NMI/GT:', checkpoint['nmi']
        if 'best_nmi' in checkpoint:
            best_nmi = checkpoint['best_nmi']
            print 'Best NMI/GT:', best_nmi
        print 'Best score:', best_score
        if 'cur_score' in checkpoint:
            print 'Current score:', checkpoint['cur_score']
        model.load_state_dict(checkpoint['state_dict'])
        print 'state dict loaded'
        optimizer.load_state_dict(checkpoint['optimizer'])
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr
            param_group['initial_lr'] = args.lr
    logger = SummaryWriter(log_dir=args.output_dir)

    ### Data loading ###
    num_gt_classes = 1000
    split_dirs = {
        'train':
        join(args.data,
             'train' if args.imagenet_version == 1 else 'train_256'),
        'val':
        join(
            args.data, 'val' if args.imagenet_version == 1 else 'val_256'
        )  # we get lower accuracy with val_256, probably because of jpeg compression
    }
    dataset_indices = dict()
    for key in ['train', 'val']:
        index_path = join(args.data,
                          os.path.basename(split_dirs[key]) + '_index.json')

        if os.path.exists(index_path):
            with open(index_path) as json_file:
                dataset_indices[key] = json.load(json_file)
        else:
            print 'Indexing ' + key
            dataset_indices[key] = index_imagenet(split_dirs[key], index_path)

    assert dataset_indices['train']['class_to_idx'] == \
           dataset_indices['val']['class_to_idx']
    if args.dbg:
        max_images = 1000
        print 'DBG: WARNING! Truncating train dataset to {} images'.format(
            max_images)
        dataset_indices['train']['samples'] = dataset_indices['train'][
            'samples'][:max_images]
        dataset_indices['val']['samples'] = dataset_indices['val'][
            'samples'][:max_images]

    num_workers = args.workers  # if args.unsupervised else max(1, args.workers / 2)

    print '[TRAIN]...'
    if args.unsupervised:
        train_loader_gt = create_data_loader(
            split_dirs['train'],
            dataset_indices['train'],
            is_sobel,
            sobel_normalized=args.sobel_normalized,
            aug='random_crop_flip',
            shuffle='shuffle' if not args.fast_dataflow else 'shuffle_buffer',
            num_workers=num_workers,
            use_fast_dataflow=args.fast_dataflow,
            buffer_size=args.buffer_size)
        eval_gt_aug = '10_crop'
        val_loader_gt = create_data_loader(
            split_dirs['val'],
            dataset_indices['val'],
            is_sobel,
            sobel_normalized=args.sobel_normalized,
            aug=eval_gt_aug,
            batch_size=26,  # WARNING. Decrease the batch size because of Memory
            shuffle='shuffle',
            num_workers=num_workers,
            use_fast_dataflow=False,
            buffer_size=args.buffer_size)
    else:
        train_loader = create_data_loader(
            split_dirs['train'],
            dataset_indices['train'],
            is_sobel,
            sobel_normalized=args.sobel_normalized,
            aug='random_crop_flip',
            shuffle='shuffle' if not args.fast_dataflow else 'shuffle_buffer',
            num_workers=num_workers,
            use_fast_dataflow=args.fast_dataflow,
            buffer_size=args.buffer_size)
        print '[VAL]...'
        # with GT labels!
        val_loader = create_data_loader(
            split_dirs['val'],
            dataset_indices['val'],
            is_sobel,
            sobel_normalized=args.sobel_normalized,
            aug='central_crop',
            batch_size=args.batch_size,
            shuffle='shuffle' if not args.fast_dataflow else None,
            num_workers=num_workers,
            use_fast_dataflow=args.fast_dataflow,
            buffer_size=args.buffer_size)
    ###############################################################################

    # StepLR(optimizer, step_size=args.decay_step, gamma=args.decay_gamma)
    if args.scheduler == 'multi_step':
        scheduler = MultiStepLR(optimizer,
                                milestones=[30, 60, 80],
                                gamma=args.decay_gamma)
    elif args.scheduler == 'multi_step2':
        scheduler = MultiStepLR(optimizer,
                                milestones=[50, 100],
                                gamma=args.decay_gamma)
    elif args.scheduler == 'cyclic':
        print 'Using Cyclic LR!'
        cyclic_lr = CyclicLr(start_epoch if args.reset_lr else 0,
                             init_lr=args.lr,
                             num_epochs_per_cycle=args.cycle,
                             epochs_pro_decay=args.decay_step,
                             lr_decay_factor=args.decay_gamma)
        scheduler = LambdaLR(optimizer, lr_lambda=cyclic_lr)
        scheduler.base_lrs = list(
            map(lambda group: 1.0, optimizer.param_groups))
    elif args.scheduler == 'step':
        step_lr = StepMinLr(start_epoch if args.reset_lr else 0,
                            init_lr=args.lr,
                            epochs_pro_decay=args.decay_step,
                            lr_decay_factor=args.decay_gamma,
                            min_lr=args.min_lr)

        scheduler = LambdaLR(optimizer, lr_lambda=step_lr)
        scheduler.base_lrs = list(
            map(lambda group: 1.0, optimizer.param_groups))
    else:
        assert False, 'wrong scheduler: ' + args.scheduler

    print 'scheduler.base_lrs=', scheduler.base_lrs
    logger.add_scalar('data/batch_size', args.batch_size, start_epoch)

    save_epoch = 50
    if not args.unsupervised:
        validate_epoch = 1
    else:
        validate_epoch = 50
        labels_holder = {
        }  # utility container to save labels from the previous clustering step

    last_lr = 100500
    for epoch in range(start_epoch, args.epochs):
        nmi_gt = None
        if epoch == start_epoch:
            if not args.unsupervised:
                validate(val_loader,
                         model,
                         criterion,
                         epoch - 1,
                         logger=logger)
            # elif start_epoch == 0:
            #     print 'validate_gt_linear'
            #     validate_gt_linear(train_loader_gt, val_loader_gt, num_gt_classes,
            #                        model, args.eval_layer, criterion, epoch - 1, lr=0.01,
            #                        num_train_epochs=2,
            #                        logger=logger, tag='val_gt_{}_{}'.format(args.eval_layer, eval_gt_aug))

        if args.unsupervised and (epoch == start_epoch
                                  or epoch % args.recluster_epoch == 0):
            train_loader, nmi_gt = unsupervised_clustering_step(
                epoch, model, is_sobel, args.sobel_normalized, split_dirs,
                dataset_indices, num_workers, labels_holder, logger,
                args.fast_dataflow)
            if args.reset_fc:
                model.module.reset_fc8()
            try:
                with open(join(args.output_dir, 'labels_holder.json'),
                          'w') as f:
                    for k in labels_holder.keys():
                        labels_holder[k] = np.asarray(
                            labels_holder[k]).tolist()
                    json.dump(labels_holder, f)
            except Exception as e:
                print e

        scheduler.step(epoch=epoch)
        if last_lr != scheduler.get_lr()[0]:
            last_lr = scheduler.get_lr()[0]
            print 'LR := {}'.format(last_lr)
        logger.add_scalar('data/lr', scheduler.get_lr()[0], epoch)
        logger.add_scalar('data/v', args.imagenet_version, epoch)
        logger.add_scalar('data/weight_decay', args.weight_decay, epoch)
        logger.add_scalar('data/dropout7_prob', args.dropout7_prob, epoch)

        top1_avg, top5_avg, loss_avg = \
            train(train_loader, model, criterion, optimizer,
                  epoch, args.epochs,
                  log_iter=100, logger=logger)

        if (epoch + 1) % validate_epoch == 0:
            # evaluate on validation set
            if not args.unsupervised:
                score = validate(val_loader,
                                 model,
                                 criterion,
                                 epoch,
                                 logger=logger)
            else:
                score = validate_gt_linear(
                    train_loader_gt,
                    val_loader_gt,
                    num_gt_classes,
                    model,
                    args.eval_layer,
                    criterion,
                    epoch,
                    lr=0.01,
                    num_train_epochs=args.epochs_train_linear,
                    logger=logger,
                    tag='val_gt_{}_{}'.format(args.eval_layer, eval_gt_aug))

            # remember best prec@1 and save checkpoint
            is_best = score > best_score
            best_score = max(score, best_score)
            best_ckpt_suffix = ''
        else:
            score = None
            if nmi_gt is not None and nmi_gt > best_nmi:
                best_nmi = nmi_gt
                best_ckpt_suffix = '_nmi'
                is_best = True
            else:
                is_best = False
                best_ckpt_suffix = ''

        if (epoch + 1) % save_epoch == 0:
            filepath = join(args.output_dir,
                            'checkpoint-{:05d}.pth.tar'.format(epoch + 1))
        else:
            filepath = join(args.output_dir, 'checkpoint.pth.tar')
        save_dict = {
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_score': best_score,
            'top1_avg_accuracy_train': top1_avg,
            'optimizer': optimizer.state_dict(),
        }
        if nmi_gt is not None:
            save_dict['nmi'] = nmi_gt
            save_dict['best_nmi'] = best_nmi
        if score is not None:
            save_dict['cur_score'] = score
        save_checkpoint(save_dict,
                        is_best=is_best,
                        filepath=filepath,
                        best_suffix=best_ckpt_suffix)
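The cyclic and step schedules above are implemented by wrapping a custom callable in LambdaLR and overriding base_lrs with 1.0, so the callable returns the absolute learning rate rather than a multiplicative factor. Here is a minimal sketch of that trick, assuming an illustrative step-decay-with-floor schedule; the decay numbers are not the ones used by the original CyclicLr/StepMinLr classes.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR


def step_lr(epoch, init_lr=0.01, epochs_per_decay=10, decay_factor=0.1, min_lr=1e-5):
    """Return the absolute learning rate for a given epoch (step decay with a floor)."""
    return max(init_lr * decay_factor ** (epoch // epochs_per_decay), min_lr)


model = torch.nn.Linear(10, 2)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = LambdaLR(optimizer, lr_lambda=step_lr)
scheduler.base_lrs = [1.0 for _ in optimizer.param_groups]  # lambda now yields the LR itself

for epoch in range(25):
    # ... one epoch of training would go here ...
    scheduler.step()
    print(epoch, 'lr =', scheduler.get_last_lr()[0])  # get_last_lr needs PyTorch >= 1.4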
Example #13
0
def train_one_fold(train_images: Any, train_scores: Any, val_images: Any,
                   val_scores: Any, fold: int) -> None:
    # create model
    logger.info("=> using pre-trained model '{}'".format(opt.MODEL.ARCH))
    if opt.MODEL.ARCH.startswith('resnet'):
        model = models.__dict__[opt.MODEL.ARCH](pretrained=True)
    else:
        model = pretrainedmodels.__dict__[opt.MODEL.ARCH](
            pretrained='imagenet')

    # for child in list(model.children())[:-1]:
    #     print("freezing layer:", child)
    #
    #     for param in child.parameters():
    #         param.requires_grad = False

    train_dataset = DataGenerator(train_images,
                                  train_scores,
                                  transform=transform)
    val_dataset = DataGenerator(val_images, val_scores, transform=transform)

    logger.info('{} samples in train dataset'.format(len(train_dataset)))
    logger.info('{} samples in validation dataset'.format(len(val_dataset)))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.TRAIN.BATCH_SIZE,
                                               shuffle=opt.TRAIN.SHUFFLE,
                                               num_workers=opt.TRAIN.WORKERS)

    test_loader = torch.utils.data.DataLoader(val_dataset,
                                              batch_size=opt.TRAIN.BATCH_SIZE,
                                              shuffle=False,
                                              num_workers=opt.TRAIN.WORKERS)

    if opt.MODEL.ARCH.startswith('resnet'):
        assert (opt.MODEL.INPUT_SIZE % 32 == 0)
        model.avgpool = nn.AvgPool2d(opt.MODEL.INPUT_SIZE // 32, stride=1)
        model.fc = nn.Linear(model.fc.in_features * 5, opt.NUM_CLASSES)
        model = torch.nn.DataParallel(model).cuda()
    elif opt.MODEL.ARCH.startswith('se'):
        assert (opt.MODEL.INPUT_SIZE % 32 == 0)
        model.avgpool = nn.AvgPool2d(opt.MODEL.INPUT_SIZE // 32, stride=1)
        model.last_linear = nn.Linear(model.last_linear.in_features,
                                      opt.NUM_FEATURES)
        model = torch.nn.DataParallel(model).cuda()
    else:
        raise NotImplementedError

    params = filter(lambda p: p.requires_grad, model.module.parameters())
    optimizer = optim.Adam(params, opt.TRAIN.LEARNING_RATE)
    lr_scheduler = MultiStepLR(optimizer,
                               opt.TRAIN.LR_MILESTONES,
                               gamma=opt.TRAIN.LR_GAMMA,
                               last_epoch=-1)

    if opt.TRAIN.RESUME is None:
        last_epoch = 0
        logger.info("training will start from epoch {}".format(last_epoch + 1))
    else:
        last_checkpoint = torch.load(opt.TRAIN.RESUME)
        assert (last_checkpoint['arch'] == opt.MODEL.ARCH)
        model.module.load_state_dict(last_checkpoint['state_dict'])
        optimizer.load_state_dict(last_checkpoint['optimizer'])
        logger.info("checkpoint '{}' was loaded.".format(opt.TRAIN.RESUME))

        last_epoch = last_checkpoint['epoch']
        logger.info("training will be resumed from epoch {}".format(
            last_checkpoint['epoch']))

    criterion = nn.CrossEntropyLoss(class_weights)
    #     last_fc = nn.Linear(in_features=opt.NUM_FEATURES, out_features=1).cuda()
    #     relu = nn.ReLU(inplace=True).cuda()

    best_map3 = 0.0
    best_epoch = 0

    train_losses: List[float] = []
    train_metrics: List[float] = []
    test_losses: List[float] = []
    test_metrics: List[float] = []

    for epoch in range(last_epoch + 1, opt.TRAIN.EPOCHS + 1):
        logger.info('-' * 50)
        lr_scheduler.step(epoch)
        logger.info('lr: {}'.format(lr_scheduler.get_lr()))

        train(train_loader, model, criterion, optimizer, epoch, train_losses,
              train_metrics)
        map3 = validate(test_loader, model, criterion, test_losses,
                        test_metrics)
        is_best = map3 > best_map3
        if is_best:
            best_epoch = epoch
            best_map3 = map3

        if epoch % opt.TRAIN.SAVE_FREQ == 0:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': opt.MODEL.ARCH,
                    'state_dict': model.module.state_dict(),
                    'best_map3': best_map3,
                    'map3': map3,
                    'optimizer': optimizer.state_dict(),
                }, '{}_[{}]_{:.04f}_fold{}.pk'.format(opt.MODEL.ARCH, epoch,
                                                      map3, fold))

        if is_best:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': opt.MODEL.ARCH,
                    'state_dict': model.module.state_dict(),
                    'best_map3': best_map3,
                    'map3': map3,
                    'optimizer': optimizer.state_dict(),
                }, 'best_model_fold{}.pk'.format(fold))

    logger.info('best MAP@3: {:.04f}'.format(best_map3))

    #best_checkpoint_path = osp.join(opt.EXPERIMENT.DIR, 'best_model.pk')
    #logger.info("Loading parameters from the best checkpoint '{}',".format(best_checkpoint_path))
    #checkpoint = torch.load(best_checkpoint_path)
    #logger.info("which has a single crop map3 {:.02f}%.".format(checkpoint['map3']))
    #model.load_state_dict(checkpoint['state_dict'])

    best_epoch = np.argmin(test_losses)
    best_loss = test_losses[best_epoch]
    plt.figure(0)
    x = np.arange(last_epoch + 1, opt.TRAIN.EPOCHS + 1)
    plt.plot(x, train_losses, '-+')
    plt.plot(x, test_losses, '-+')
    plt.scatter(best_epoch + 1, best_loss, c='C1', marker='^', s=80)
    # plt.ylim(ymin=0, ymax=5)
    plt.grid(linestyle=':')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.title('loss over epoch')
    plt.savefig(osp.join(opt.EXPERIMENT.DIR, 'loss_curves_%d.png' % fold))

    best_epoch = np.argmax(test_metrics)  # higher MAP@3 is better
    best_metrics = test_metrics[best_epoch]
    plt.figure(1)
    plt.plot(x, train_metrics, '-+')
    plt.plot(x, test_metrics, '-+')
    plt.scatter(best_epoch + 1, best_metrics, c='C1', marker='^', s=80)
    # plt.ylim(ymin=0, ymax=100)
    plt.grid(linestyle=':')
    plt.xlabel('epoch')
    plt.ylabel('map3')
    plt.title('map3 over epoch')
    plt.savefig(osp.join(opt.EXPERIMENT.DIR, 'accuracy_curves_%d.png' % fold))
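train_one_fold above builds MultiStepLR with last_epoch=-1 and then calls lr_scheduler.step(epoch) with an explicit epoch to realign the schedule after a resume. A simpler way to achieve the same effect is to save and restore the scheduler's state_dict together with the model and optimizer; the sketch below assumes a hypothetical checkpoint file name and key layout.

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(10, 2)
optimizer = Adam(model.parameters(), lr=1e-3)
scheduler = MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)

# --- save (e.g. at the end of epoch 5) ---
torch.save({'epoch': 5,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()}, 'ckpt.pth')

# --- resume ---
ckpt = torch.load('ckpt.pth')
model.load_state_dict(ckpt['state_dict'])
optimizer.load_state_dict(ckpt['optimizer'])
scheduler.load_state_dict(ckpt['scheduler'])  # milestones are honoured on resume
start_epoch = ckpt['epoch'] + 1

for epoch in range(start_epoch, 25):
    # ... train and validate one epoch here ...
    scheduler.step()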
Example #14
0
class TrainerLaneDetector:
    def __init__(self, network, network_params, training_params):
        """
        Initialize
        """
        super(TrainerLaneDetector, self).__init__()

        self.network_params = network_params
        self.training_params = training_params

        self.lane_detection_network = network

        self.optimizer = torch.optim.Adam(
            self.lane_detection_network.parameters(),
            lr=training_params.l_rate,
            weight_decay=training_params.weight_decay)
        self.lr_scheduler = MultiStepLR(
            self.optimizer,
            milestones=training_params.scheduler_milestones,
            gamma=training_params.scheduler_gamma)
        self.current_epoch = 0
        self.current_step = 0
        self.current_loss = None

    def train(self, batch_sample, epoch, step):
        """ train
        :param
            inputs -> ndarray [#batch, 3, 256, 512]
            target_lanes -> [[4, 48],...,], len(List[ndarray]) = 8, ndarray -> [lanes=4, sample_pts=48]
            target_h -> [[4, 48],...,], len(List[ndarray]) = 8, ndarray -> [lanes=4, sample_pts=48]

        compute loss function and optimize
        """
        grid_x = self.network_params.grid_x
        grid_y = self.network_params.grid_y
        feature_size = self.network_params.feature_size

        real_batch_size = batch_sample["image"].shape[0]

        # generate ground truth
        ground_truth_point = batch_sample["detection_gt"]
        ground_truth_instance = batch_sample["instance_gt"]

        # convert numpy array to torch tensor
        ground_truth_point = torch.from_numpy(ground_truth_point).float()
        ground_truth_point = Variable(ground_truth_point).cuda()
        ground_truth_point.requires_grad = False

        ground_truth_instance = torch.from_numpy(ground_truth_instance).float()
        ground_truth_instance = Variable(ground_truth_instance).cuda()
        ground_truth_instance.requires_grad = False

        # inference lane_detection_network
        result = self.predict(batch_sample["image"])

        metrics = {}
        lane_detection_loss = 0
        for hourglass_id, intermediate_loss in enumerate(result):
            confidance, offset, feature = intermediate_loss
            # e.g.
            # confidence shape = [8, 1, 32, 64]
            # offset shape = [8, 3, 32, 64]
            # feature shape = [8, 4, 32, 64]  for instance segmentation
            # quote
            # "The feature size is set to 4, and this size is observed to have no major effect for the performance."

            # compute loss for point prediction
            offset_loss = 0
            exist_condidence_loss = 0
            nonexist_confidence_loss = 0

            # exist confidence loss
            confidance_gt = ground_truth_point[:, 0, :, :]  # [8,1,32,64]
            confidance_gt = confidance_gt.view(real_batch_size, 1, grid_y,
                                               grid_x)  # [8,1,32,64]
            exist_condidence_loss = torch.sum(
                (confidance_gt[confidance_gt == 1] -
                 confidance[confidance_gt == 1])**
                2) / torch.sum(confidance_gt == 1)

            # non-exist confidence loss
            nonexist_confidence_loss = torch.sum(
                (confidance_gt[confidance_gt == 0] -
                 confidance[confidance_gt == 0])**
                2) / torch.sum(confidance_gt == 0)

            # offset loss
            offset_x_gt = ground_truth_point[:, 1:2, :, :]
            offset_y_gt = ground_truth_point[:, 2:3, :, :]

            predict_x = offset[:, 0:1, :, :]
            predict_y = offset[:, 1:2, :, :]

            x_offset_loss = torch.sum((offset_x_gt[confidance_gt == 1] -
                                       predict_x[confidance_gt == 1])**
                                      2) / torch.sum(confidance_gt == 1)
            y_offset_loss = torch.sum((offset_y_gt[confidance_gt == 1] -
                                       predict_y[confidance_gt == 1])**
                                      2) / torch.sum(confidance_gt == 1)

            offset_loss = (x_offset_loss + y_offset_loss) / 2

            # compute loss for similarity
            sisc_loss = 0
            disc_loss = 0

            feature_map = feature.view(real_batch_size, feature_size, 1,
                                       grid_y * grid_x)  # [8, 4, 1, 2048]
            feature_map = feature_map.expand(
                real_batch_size, feature_size, grid_y * grid_x,
                grid_y * grid_x).detach()  # [8, 4, 2048, 2048]

            point_feature = feature.view(real_batch_size, feature_size,
                                         grid_y * grid_x, 1)  # [8, 4, 2048, 1]
            point_feature = point_feature.expand(
                real_batch_size, feature_size, grid_y * grid_x,
                grid_y * grid_x)  # .detach()  [8, 4, 2048, 2048]

            distance_map = (feature_map - point_feature)**2
            distance_map = torch.norm(distance_map,
                                      dim=1).view(real_batch_size, 1,
                                                  grid_y * grid_x,
                                                  grid_y * grid_x)

            # same instance
            sisc_loss = torch.sum(
                distance_map[ground_truth_instance == 1]) / torch.sum(
                    ground_truth_instance == 1)

            # different instance, same class
            disc_loss = self.training_params.K1 - distance_map[
                ground_truth_instance ==
                2]  # self.p.K1/distance_map[ground_truth_instance==2] + (self.p.K1-distance_map[ground_truth_instance==2])
            disc_loss[disc_loss < 0] = 0
            disc_loss = torch.sum(disc_loss) / torch.sum(
                ground_truth_instance == 2)

            lane_loss = self.training_params.constant_exist * exist_condidence_loss + self.training_params.constant_nonexist * nonexist_confidence_loss + self.training_params.constant_offset * offset_loss
            instance_loss = self.training_params.constant_alpha * sisc_loss + self.training_params.constant_beta * disc_loss
            lane_detection_loss = lane_detection_loss + self.training_params.constant_lane_loss * lane_loss + self.training_params.constant_instance_loss * instance_loss

            metrics["hourglass_" + str(hourglass_id) +
                    "_same_instance_same_class_loss"] = sisc_loss.item()
            metrics["hourglass_" + str(hourglass_id) +
                    "_diff_instance_same_class_loss"] = disc_loss.item()
            metrics["hourglass_" + str(hourglass_id) +
                    "_instance_loss"] = instance_loss.item()
            metrics["hourglass_" + str(
                hourglass_id
            ) + "_confidence_loss"] = self.training_params.constant_exist * exist_condidence_loss.item(
            ) + self.training_params.constant_nonexist * nonexist_confidence_loss.item(
            )
            metrics[
                "hourglass_" + str(hourglass_id) +
                "_offset_loss"] = self.training_params.constant_offset * offset_loss.item(
                )
            metrics["hourglass_" + str(
                hourglass_id
            ) + "_total_loss"] = self.training_params.constant_lane_loss * lane_loss.item(
            ) + self.training_params.constant_instance_loss * instance_loss.item(
            )

        metrics["pinet_total_loss"] = lane_detection_loss.item()

        self.optimizer.zero_grad()
        lane_detection_loss.backward()
        self.optimizer.step()

        del confidance, offset, feature
        del ground_truth_point, ground_truth_instance
        del feature_map, point_feature, distance_map
        del exist_condidence_loss, nonexist_confidence_loss, offset_loss, sisc_loss, disc_loss, lane_loss, instance_loss

        # update lr based on epoch
        if epoch != self.current_epoch:
            self.current_epoch = epoch
            self.lr_scheduler.step()

        if step != self.current_step:
            self.current_step = step

        self.current_loss = lane_detection_loss.item()

        return self.current_loss, metrics, result

    def predict(self, inputs: np.ndarray):
        """
        predict lanes

        :param inputs -> [batch_size, 3, 256, 512]
        :return:
        """
        inputs = torch.from_numpy(inputs).float()
        inputs = Variable(inputs).cuda()
        return self.lane_detection_network(inputs)

    def test_on_image(
        self,
        test_images: np.ndarray,
        threshold_confidence: float = 0.81
    ) -> Tuple[List[List[int]], List[List[int]], List[np.ndarray]]:
        """ predict, then post-process

        :param test_images: input image or image batch
        :param threshold_confidence: key points whose confidence exceeds this threshold are accepted
        """
        rank = len(test_images.shape)
        if rank == 3:
            batch_image = np.expand_dims(test_images, 0)
        elif rank == 4:
            batch_image = test_images
        else:
            raise IndexError

        # start = time.time()
        result = self.predict(batch_image)  # accept rank = 4 only
        # end = time.time()
        # print(f"predict time: {end - start} [sec]")  # [second]

        confidences, offsets, instances = result[
            -1]  # trainer use different output, compared with inference

        num_batch = batch_image.shape[0]

        out_x = []
        out_y = []
        out_images = []

        for i in range(num_batch):
            # test on test data set
            image = deepcopy(batch_image[i])
            image = np.rollaxis(image, axis=2, start=0)
            image = np.rollaxis(image, axis=2, start=0) * 255.0
            image = image.astype(np.uint8).copy()

            confidence = confidences[i].view(
                self.network_params.grid_y,
                self.network_params.grid_x).cpu().data.numpy()

            offset = offsets[i].cpu().data.numpy()
            offset = np.rollaxis(offset, axis=2, start=0)
            offset = np.rollaxis(offset, axis=2, start=0)

            instance = instances[i].cpu().data.numpy()
            instance = np.rollaxis(instance, axis=2, start=0)
            instance = np.rollaxis(instance, axis=2, start=0)

            # generate point and cluster
            raw_x, raw_y = generate_result(confidence, offset, instance,
                                           threshold_confidence)

            # eliminate fewer points
            in_x, in_y = eliminate_fewer_points(raw_x, raw_y)

            # sort points along y
            in_x, in_y = sort_along_y(in_x, in_y)
            in_x, in_y = eliminate_out(in_x, in_y)
            in_x, in_y = sort_along_y(in_x, in_y)
            in_x, in_y = eliminate_fewer_points(in_x, in_y)

            result_image = draw_points(in_x, in_y, deepcopy(image))

            out_x.append(in_x)
            out_y.append(in_y)
            out_images.append(result_image)

        return out_x, out_y, out_images

    def training_mode(self):
        """ Training mode """
        self.lane_detection_network.train()

    def evaluate_mode(self):
        """ evaluate(test mode) """
        self.lane_detection_network.eval()

    def load_weights(self, path):
        self.lane_detection_network.load_state_dict(torch.load(path), False)

    def load_weights_v2(self, path):
        checkpoint = torch.load(path)
        self.lane_detection_network.load_state_dict(
            checkpoint['model_state_dict'], strict=False)
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.current_epoch = checkpoint['epoch']
        self.current_step = checkpoint['step']
        self.current_loss = checkpoint['loss']

    def save_model(self, path):
        """ Save model """
        torch.save(self.lane_detection_network.state_dict(), path)

    def save_model_v2(self, path):
        """ Save model """
        torch.save(
            {
                'epoch': self.current_epoch,
                'step': self.current_step,
                'model_state_dict': self.lane_detection_network.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': self.current_loss,
            }, path)

    def get_lr(self):
        return self.lr_scheduler.get_lr()

    @staticmethod
    def count_parameters(model: nn.Module):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
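TrainerLaneDetector drives training by steps but decays the learning rate per epoch, so it only advances MultiStepLR when the epoch counter passed into train() changes. Below is a minimal sketch of that pattern; the linear model, loss and random batches are placeholders for illustration.

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(8, 1)
loss_fn = torch.nn.MSELoss()
optimizer = Adam(model.parameters(), lr=1e-3)
scheduler = MultiStepLR(optimizer, milestones=[2, 4], gamma=0.5)

current_epoch = 0
steps_per_epoch = 10
for step in range(60):
    x, y = torch.randn(16, 8), torch.randn(16, 1)
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

    epoch = step // steps_per_epoch
    if epoch != current_epoch:      # advance the scheduler once per new epoch
        current_epoch = epoch
        scheduler.step()
        print("epoch {}: lr -> {:.6f}".format(epoch, scheduler.get_last_lr()[0]))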
Example #15
0
def train_net(train_loader,
              valid_loader,
              class_labels,
              lab_results_dir,
              learning_rate=0.0001,
              is_lr_scheduled=True,
              max_epoch=1,
              save_epochs=[10, 20, 30]):
    # Measure execution time
    train_start = time.time()
    start_time = strftime('SSD__%dth_%H:%M_', gmtime())

    # Define the Net
    print('num_class: ', len(class_labels))
    print('class_labels: ', class_labels)
    ssd_net = SSD(len(class_labels))
    # Set the parameter defined in the net to GPU
    net = ssd_net

    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.backends.cudnn.benchmark = True
        net.cuda()

    # Define the loss
    center_var = 0.1
    size_var = 0.2
    criterion = MultiboxLoss([center_var, center_var, size_var, size_var],
                             iou_threshold=0.5,
                             neg_pos_ratio=3.0)

    # Define Optimizer
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    # optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9,
    #                       weight_decay=0.0005)
    if is_lr_scheduled:
        scheduler = MultiStepLR(optimizer,
                                milestones=[10, 30, 50, 70],
                                gamma=0.1)

    # Train data
    conf_losses = []
    loc_losses = []
    v_conf_losses = []
    v_loc_losses = []
    itr = 0
    train_log = []
    valid_log = []
    for epoch_idx in range(0, max_epoch):

        # decrease learning rate
        if is_lr_scheduled:
            scheduler.step()
            print('\n\n===> lr: {}'.format(scheduler.get_lr()[0]))

        # Save the trained network
        if epoch_idx in save_epochs:
            temp_file = start_time + 'epoch_{}'.format(epoch_idx)
            net_state = net.state_dict()  # serialize the instance
            torch.save(net_state, lab_results_dir + temp_file + '__model.pth')
            print('================> Temp file is created: ',
                  lab_results_dir + temp_file + '__model.pth')

        # iterate the mini-batches:
        for train_batch_idx, data in enumerate(train_loader):
            train_images, train_labels, train_bboxes, prior_bbox = data

            # Switch to train model
            net.train()

            # Forward
            train_img = Variable(train_images.clone().cuda())
            train_bbox = Variable(train_bboxes.clone().cuda())
            train_label = Variable(train_labels.clone().cuda())

            train_out_confs, train_out_locs = net.forward(train_img)
            # locations(feature map base) -> bbox(center form)
            train_out_bbox = loc2bbox(train_out_locs,
                                      prior_bbox[0].unsqueeze(0))

            # update the parameter gradients as zero
            optimizer.zero_grad()

            # Compute the loss
            conf_loss, loc_loss = criterion.forward(train_out_confs,
                                                    train_out_bbox,
                                                    train_label, train_bbox)
            train_loss = conf_loss + loc_loss

            # Do the backward to compute the gradient flow
            train_loss.backward()

            # Update the parameters
            optimizer.step()

            conf_losses.append((itr, conf_loss))
            loc_losses.append((itr, loc_loss))

            itr += 1
            if train_batch_idx % 20 == 0:
                train_log_temp = '[Train]epoch: %d itr: %d Conf Loss: %.4f Loc Loss: %.4f' % (
                    epoch_idx, itr, conf_loss, loc_loss)
                train_log += (train_log_temp + '\n')
                print(train_log_temp)
                if False:  # check input tensor
                    image_s = train_images[0, :, :, :].cpu().numpy().astype(
                        np.float32).transpose().copy()  # c , h, w -> h, w, c
                    image_s = ((image_s + 1) / 2)
                    bbox_cr_s = torch.cat([
                        train_bboxes[..., :2] - train_bboxes[..., 2:] / 2,
                        train_bboxes[..., :2] + train_bboxes[..., 2:] / 2
                    ],
                                          dim=-1)
                    bbox_prior_s = bbox_cr_s[0, :].cpu().numpy().astype(
                        np.float32).reshape(
                            (-1, 4)).copy()  # First sample in batch
                    bbox_prior_s = (bbox_prior_s * 300)
                    label_prior_s = train_labels[0, :].cpu().numpy().astype(
                        np.float32).copy()
                    bbox_s = bbox_prior_s[label_prior_s > 0]
                    label_s = (label_prior_s[label_prior_s > 0]).astype(
                        np.uint8)

                    for idx in range(0, len(label_s)):
                        cv2.rectangle(image_s,
                                      (bbox_s[idx][0], bbox_s[idx][1]),
                                      (bbox_s[idx][2], bbox_s[idx][3]),
                                      (255, 0, 0), 2)
                        cv2.putText(image_s, class_labels[label_s[idx]],
                                    (bbox_s[idx][0], bbox_s[idx][1]),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 0, 0),
                                    1, cv2.LINE_AA)

                    plt.imshow(image_s)
                    plt.show()

            # validation
            if train_batch_idx % 200 == 0:
                net.eval()  # Evaluation mode
                v_conf_subsum = torch.zeros(
                    1)  # collect the validation losses for avg.
                v_loc_subsum = torch.zeros(1)
                v_itr_max = 5
                for valid_itr, data in enumerate(valid_loader):
                    valid_image, valid_label, valid_bbox, prior_bbox = data

                    valid_image = Variable(valid_image.cuda())
                    valid_bbox = Variable(valid_bbox.cuda())
                    valid_label = Variable(valid_label.cuda())

                    # Forward and compute loss
                    with torch.no_grad():  # disable gradient tracking to save memory
                        valid_out_confs, valid_out_locs = net.forward(
                            valid_image)
                    valid_out_bbox = loc2bbox(
                        valid_out_locs,
                        prior_bbox[0].unsqueeze(0))  # loc -> bbox(center form)

                    valid_conf_loss, valid_loc_loss = criterion.forward(
                        valid_out_confs, valid_out_bbox, valid_label,
                        valid_bbox)

                    v_conf_subsum += valid_conf_loss
                    v_loc_subsum += valid_loc_loss
                    valid_itr += 1
                    if valid_itr > v_itr_max:
                        break

                # avg. valid loss
                v_conf_losses.append((itr, v_conf_subsum / v_itr_max))
                v_loc_losses.append((itr, v_loc_subsum / v_itr_max))

                valid_log_temp = '==>[Valid]epoch: %d itr: %d Conf Loss: %.4f Loc Loss: %.4f' % (
                    epoch_idx, itr, v_conf_subsum / v_itr_max,
                    v_loc_subsum / v_itr_max)
                valid_log += (valid_log_temp + '\n')
                print(valid_log_temp)

    # Measure the time
    train_end = time.time()
    m, s = divmod(train_end - train_start, 60)
    h, m = divmod(m, 60)

    # Save the result
    results_file_name = start_time + 'itr_{}'.format(itr)

    train_data = {
        'conf_losses': np.asarray(conf_losses),
        'loc_losses': np.asarray(loc_losses),
        'v_conf_losses': np.asarray(v_conf_losses),
        'v_loc_losses': np.asarray(v_loc_losses),
        'learning_rate': learning_rate,
        'total_itr': itr,
        'max_epoch': max_epoch,
        'train_time': '%d:%02d:%02d' % (h, m, s)
    }

    torch.save(train_data, lab_results_dir + results_file_name + '.loss')

    # Save the trained network
    net_state = net.state_dict()  # serialize the instance
    torch.save(net_state, lab_results_dir + results_file_name + '__model.pth')

    # Save the train/valid log
    torch.save({'log': train_log},
               lab_results_dir + results_file_name + '__train.log')
    torch.save({'log': valid_log},
               lab_results_dir + results_file_name + '__valid.log')

    return lab_results_dir + results_file_name
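Besides the per-epoch LR decay, the SSD loop above validates every 200 training iterations on a capped number of validation batches and averages their losses. Here is a small, self-contained sketch of that periodic-validation idea, using a toy regression model and random tensors in place of the SSD network and MultiboxLoss.

import torch
from torch.utils.data import DataLoader, TensorDataset

model = torch.nn.Linear(4, 1)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train_ds = TensorDataset(torch.randn(256, 4), torch.randn(256, 1))
valid_ds = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_ds, batch_size=16)

VALIDATE_EVERY = 5    # training iterations between validation passes
MAX_VALID_BATCHES = 3  # only average over a handful of validation batches

for itr, (x, y) in enumerate(train_loader):
    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

    if itr % VALIDATE_EVERY == 0:
        model.eval()
        total, n = 0.0, 0
        with torch.no_grad():
            for v_idx, (vx, vy) in enumerate(valid_loader):
                total += loss_fn(model(vx), vy).item()
                n += 1
                if v_idx + 1 >= MAX_VALID_BATCHES:
                    break
        print("itr {}: train loss {:.4f}, valid loss {:.4f}".format(itr, loss.item(), total / n))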
Example #16
0
# begin training process
if __name__ == '__main__':
    # model selection
    print('===> Building model')
    # criterion = sum_squared_error()

    optimizer = optim.Adam(
        net.parameters(),
        lr=LEARNING_RATE,
    )
    scheduler = MultiStepLR(optimizer, milestones=[5, 35, 50], gamma=0.25)

    for epoch in range(EPOCH):  #
        scheduler.step(epoch)
        print("Decaying learning rate to %g" % scheduler.get_lr()[0])
        for tex in range(1):
            mode = np.random.randint(0, 4)
            net.train()

            train2 = data_aug(train2, mode)
            train3 = data_aug(train3, mode)

            print('epochs:', epoch)

            channels2 = 72  # 93 channels
            channels3 = 72  # 191 channels

            data_patches2, data_cubic_patches2 = datagenerator(
                train2, channels2)
            data_patches3, data_cubic_patches3 = datagenerator(
Example #17
0
def main():
    if not torch.cuda.is_available():
        raise Exception("need gpu to train network!")

    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    cudnn.benchmark = True
    cudnn.enabled = True

    logger = get_logger(__name__, Config.log)

    Config.gpus = torch.cuda.device_count()
    logger.info("use {} gpus".format(Config.gpus))
    config = {
        key: value
        for key, value in Config.__dict__.items() if not key.startswith("__")
    }
    logger.info(f"args: {config}")

    start_time = time.time()

    # dataset and dataloader
    logger.info("start loading data")

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = ImageFolder(Config.train_dataset_path, train_transform)
    train_loader = DataLoader(
        train_dataset,
        batch_size=Config.batch_size,
        shuffle=True,
        num_workers=Config.num_workers,
        pin_memory=True,
    )
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    val_dataset = ImageFolder(Config.val_dataset_path, val_transform)
    val_loader = DataLoader(
        val_dataset,
        batch_size=Config.batch_size,
        num_workers=Config.num_workers,
        pin_memory=True,
    )
    logger.info("finish loading data")

    # network
    net = ChannelDistillResNet1834(Config.num_classes, Config.dataset_type)
    net = nn.DataParallel(net).cuda()

    # loss and optimizer
    criterion = []
    for loss_item in Config.loss_list:
        loss_name = loss_item["loss_name"]
        loss_type = loss_item["loss_type"]
        if "kd" in loss_type:
            criterion.append(losses.__dict__[loss_name](loss_item["T"]).cuda())
        else:
            criterion.append(losses.__dict__[loss_name]().cuda())

    optimizer = SGD(net.parameters(),
                    lr=Config.lr,
                    momentum=0.9,
                    weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1)

    # only evaluate
    if Config.evaluate:
        # load best model
        if not os.path.isfile(Config.evaluate):
            raise Exception(
                f"{Config.evaluate} is not a file, please check it again")
        logger.info("start evaluating")
        logger.info(f"start resuming model from {Config.evaluate}")
        checkpoint = torch.load(Config.evaluate,
                                map_location=torch.device("cpu"))
        net.load_state_dict(checkpoint["model_state_dict"])
        prec1, prec5 = validate(val_loader, net)
        logger.info(
            f"epoch {checkpoint['epoch']:0>3d}, top1 acc: {prec1:.2f}%, top5 acc: {prec5:.2f}%"
        )
        return

    start_epoch = 1
    # resume training
    if os.path.exists(Config.resume):
        logger.info(f"start resuming model from {Config.resume}")
        checkpoint = torch.load(Config.resume,
                                map_location=torch.device("cpu"))
        start_epoch += checkpoint["epoch"]
        net.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        logger.info(
            f"finish resuming model from {Config.resume}, epoch {checkpoint['epoch']}, "
            f"loss: {checkpoint['loss']:3f}, lr: {checkpoint['lr']:.6f}, "
            f"top1_acc: {checkpoint['acc']}%, loss {checkpoint['loss']}%")

    if not os.path.exists(Config.checkpoints):
        os.makedirs(Config.checkpoints)

    logger.info("start training")
    best_acc = 0.
    for epoch in range(start_epoch, Config.epochs + 1):
        prec1, prec5, loss = train(train_loader, net, criterion, optimizer,
                                   scheduler, epoch, logger)
        logger.info(
            f"train: epoch {epoch:0>3d}, top1 acc: {prec1:.2f}%, top5 acc: {prec5:.2f}%"
        )

        prec1, prec5 = validate(val_loader, net)
        logger.info(
            f"val: epoch {epoch:0>3d}, top1 acc: {prec1:.2f}%, top5 acc: {prec5:.2f}%"
        )

        # remember best prec@1 and save checkpoint
        torch.save(
            {
                "epoch": epoch,
                "acc": prec1,
                "loss": loss,
                "lr": scheduler.get_lr()[0],
                "model_state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
            }, os.path.join(Config.checkpoints, "latest.pth"))
        if prec1 > best_acc:
            shutil.copyfile(os.path.join(Config.checkpoints, "latest.pth"),
                            os.path.join(Config.checkpoints, "best.pth"))
            best_acc = prec1

    training_time = (time.time() - start_time) / 3600
    logger.info(
        f"finish training, best acc: {best_acc:.2f}%, total training time: {training_time:.2f} hours"
    )
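The loop above always writes latest.pth with model, optimizer and scheduler state and copies it to best.pth when top-1 accuracy improves, which is what makes the resume branch earlier in the example possible. A minimal sketch of that save-latest/copy-best pattern follows; the accuracy numbers are dummies standing in for a validate() result, and the checkpoint directory is illustrative.

import os
import shutil
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR

checkpoints = './checkpoints'
os.makedirs(checkpoints, exist_ok=True)

model = torch.nn.Linear(10, 3)
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1)

best_acc = 0.0
for epoch, acc in enumerate([52.0, 61.5, 60.9], start=1):  # dummy validation accuracies
    # ... train for one epoch here (optimizer.step() per batch, scheduler.step() per epoch) ...
    torch.save({
        'epoch': epoch,
        'acc': acc,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, os.path.join(checkpoints, 'latest.pth'))
    if acc > best_acc:
        shutil.copyfile(os.path.join(checkpoints, 'latest.pth'),
                        os.path.join(checkpoints, 'best.pth'))
        best_acc = acc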
Example #18
0
def train(**args):
    """
    Train the selected model
    Args:
        rerun        (Int):        Integer indicating number of repetitions for the selected experiment 
        seed         (Int):        Integer indicating set seed for random state
        save_dir     (String):     Top level directory to generate results folder
        model        (String):     Name of selected model 
        dataset      (String):     Name of selected dataset  
        exp          (String):     Name of experiment 
        debug        (Int):        Debug state to avoid saving variables 
        load_type    (String):     Keyword indicator to evaluate the testing or validation set
        pretrained   (Int/String): Int/String indicating loading of random, pretrained or saved weights
        opt          (String):     Name of the optimizer to use ('sgd' or 'adam')
        lr           (Float):      Learning rate 
        momentum     (Float):      Momentum in optimizer 
        weight_decay (Float):      Weight_decay value 
        final_shape  ([Int, Int]): Shape of data when passed into network
        
    Return:
        None
    """

    print(
        "\n############################################################################\n"
    )
    print("Experimental Setup: ", args)
    print(
        "\n############################################################################\n"
    )

    for total_iteration in range(args['rerun']):

        # Generate Results Directory
        d = datetime.datetime.today()
        date = d.strftime('%Y%m%d-%H%M%S')
        result_dir = os.path.join(
            args['save_dir'], args['model'], '_'.join(
                (args['dataset'], args['exp'], date)))
        log_dir = os.path.join(result_dir, 'logs')
        save_dir = os.path.join(result_dir, 'checkpoints')

        if not args['debug']:
            os.makedirs(result_dir, exist_ok=True)
            os.makedirs(log_dir, exist_ok=True)
            os.makedirs(save_dir, exist_ok=True)

            # Save copy of config file
            with open(os.path.join(result_dir, 'config.yaml'), 'w') as outfile:
                yaml.dump(args, outfile, default_flow_style=False)

            # Tensorboard Element
            writer = SummaryWriter(log_dir)

        # Check if GPU is available (CUDA)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Load Network
        model = create_model_object(**args).to(device)

        # Load Data
        loader = data_loader(model_obj=model, **args)

        if args['load_type'] == 'train':
            train_loader = loader['train']
            valid_loader = loader[
                'train']  # Run accuracy on train data if only `train` selected

        elif args['load_type'] == 'train_val':
            train_loader = loader['train']
            valid_loader = loader['valid']

        else:
            sys.exit('Invalid environment selection for training, exiting')

        # END IF

        # Training Setup
        params = [p for p in model.parameters() if p.requires_grad]

        if args['opt'] == 'sgd':
            optimizer = optim.SGD(params,
                                  lr=args['lr'],
                                  momentum=args['momentum'],
                                  weight_decay=args['weight_decay'])

        elif args['opt'] == 'adam':
            optimizer = optim.Adam(params,
                                   lr=args['lr'],
                                   weight_decay=args['weight_decay'])

        else:
            sys.exit('Unsupported optimizer selected. Exiting')

        # END IF

        scheduler = MultiStepLR(optimizer,
                                milestones=args['milestones'],
                                gamma=args['gamma'])

        if isinstance(args['pretrained'], str):
            ckpt = load_checkpoint(args['pretrained'])
            model.load_state_dict(ckpt)
            start_epoch = load_checkpoint(args['pretrained'],
                                          key_name='epoch') + 1
            optimizer.load_state_dict(
                load_checkpoint(args['pretrained'], key_name='optimizer'))

            for quick_looper in range(start_epoch):
                scheduler.step()

            # END FOR

        else:
            start_epoch = 0

        # END IF

        model_loss = Losses(device=device, **args)
        acc_metric = Metrics(**args)
        best_val_acc = 0.0

        ############################################################################################################################################################################

        # Start: Training Loop
        for epoch in range(start_epoch, args['epoch']):
            running_loss = 0.0
            print('Epoch: ', epoch)

            # Setup Model To Train
            model.train()

            # Start: Epoch
            for step, data in enumerate(train_loader):
                if step % args['pseudo_batch_loop'] == 0:
                    loss = 0.0
                    optimizer.zero_grad()

                # END IF

                x_input = data['data'].to(device)
                annotations = data['annots']

                assert args['final_shape'] == list(x_input.size(
                )[-2:]), "Input to model does not match final_shape argument"
                outputs = model(x_input)
                loss = model_loss.loss(outputs, annotations)
                loss = loss * args['batch_size']
                loss.backward()

                running_loss += loss.item()

                if np.isnan(running_loss):
                    import pdb
                    pdb.set_trace()

                # END IF

                if not args['debug']:
                    # Add Learning Rate Element
                    for param_group in optimizer.param_groups:
                        writer.add_scalar(
                            args['dataset'] + '/' + args['model'] +
                            '/learning_rate', param_group['lr'],
                            epoch * len(train_loader) + step)

                    # END FOR

                    # Add Loss Element
                    writer.add_scalar(
                        args['dataset'] + '/' + args['model'] +
                        '/minibatch_loss',
                        loss.item() / args['batch_size'],
                        epoch * len(train_loader) + step)

                # END IF

                if (epoch * len(train_loader) + step + 1) % 100 == 0:
                    print('Epoch: {}/{}, step: {}/{} | train loss: {:.4f}'.format(
                        epoch, args['epoch'], step + 1, len(train_loader),
                        running_loss / float(step + 1) / args['batch_size']))

                # END IF

                if (epoch * len(train_loader) +
                    (step + 1)) % args['pseudo_batch_loop'] == 0 and step > 0:
                    # Rescale the accumulated gradients by the effective (pseudo) batch
                    # size before taking a single optimizer step
                    for param in model.parameters():
                        param.grad *= 1. / float(
                            args['pseudo_batch_loop'] * args['batch_size'])
                    optimizer.step()

                # END IF
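                # The two blocks above implement gradient accumulation: backward() sums
                # gradients over `pseudo_batch_loop` minibatches, they are rescaled by
                # the effective (pseudo) batch size (the loss was multiplied by
                # `batch_size` earlier), and a single optimizer step is taken. A minimal
                # sketch of the generic pattern (hypothetical names, for illustration):
                #   for k, (xb, yb) in enumerate(loader):
                #       loss_fn(model(xb), yb).backward()       # grads accumulate
                #       if (k + 1) % accum_steps == 0:
                #           for p in model.parameters():
                #               if p.grad is not None:
                #                   p.grad /= accum_steps       # average the summed grads
                #           optimizer.step()
                #           optimizer.zero_grad()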

            # END FOR: Epoch

            if not args['debug']:
                # Save Current Model
                save_path = os.path.join(
                    save_dir, args['dataset'] + '_epoch' + str(epoch) + '.pkl')
                save_checkpoint(epoch, step, model, optimizer, save_path)

            # END IF: Debug

            scheduler.step(epoch=epoch)
            print('Scheduler lr: %f' % scheduler.get_lr()[0])

            # Validation Accuracy
            running_acc = []
            running_acc = valid(valid_loader, running_acc, model, device,
                                acc_metric)
            if not args['debug']:
                writer.add_scalar(
                    args['dataset'] + '/' + args['model'] +
                    '/validation_accuracy', 100. * running_acc[-1],
                    epoch * len(valid_loader) + step)
            print('Accuracy of the network on the validation set: %f %%\n' %
                  (100. * running_acc[-1]))

            # Save Best Validation Accuracy Model Separately
            if best_val_acc < running_acc[-1]:
                best_val_acc = running_acc[-1]

                if not args['debug']:
                    # Save Current Model
                    save_path = os.path.join(
                        save_dir, args['dataset'] + '_best_model.pkl')
                    save_checkpoint(epoch, step, model, optimizer, save_path)

                # END IF

            # END IF

        # END FOR: Training Loop

    ############################################################################################################################################################################

        if not args['debug']:
            # Close Tensorboard Element
            writer.close()
Example #19
0
def train(data_path, models_path, backend, snapshot, crop_x, crop_y,
          batch_size, alpha, epochs, start_lr, milestones, gpu):
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    net, starting_epoch = build_network(snapshot, backend)
    data_path = os.path.abspath(os.path.expanduser(data_path))
    models_path = os.path.abspath(os.path.expanduser(models_path))
    os.makedirs(models_path, exist_ok=True)
    '''
        To follow this training routine you need a DataLoader that yields tuples in the following format:
        (Bx3xHxW FloatTensor x, BxHxW LongTensor y, BxN LongTensor y_cls) where
        x - batch of input images,
        y - batch of ground truth seg maps,
        y_cls - batch of 1D tensors of dimensionality N: N total number of classes,
        y_cls[i, T] = 1 if class T is present in image i, 0 otherwise
    '''
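    # A minimal sketch of a Dataset yielding the tuple format described above
    # (hypothetical class, shown only for illustration):
    #   class VOCTupleDataset(torch.utils.data.Dataset):
    #       def __getitem__(self, idx):
    #           x = ...      # 3xHxW FloatTensor input image
    #           y = ...      # HxW LongTensor ground-truth segmentation map
    #           y_cls = ...  # N-dim LongTensor, y_cls[T] = 1 if class T is present
    #           return x, y, y_cls
    #       def __len__(self):
    #           return ...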

    voc_data = pascalVOCLoader(root=data_path,
                               is_transform=True,
                               augmentations=None)
    # train_loader, class_weights, n_images = None, None, None
    train_loader = DataLoader(voc_data,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=0)
    max_steps = len(voc_data)
    class_weights = None

    optimizer = optim.Adam(net.parameters(), lr=start_lr)
    scheduler = MultiStepLR(optimizer,
                            milestones=[int(x) for x in milestones.split(',')],
                            gamma=0.1)
    running_score = runningScore(21)
    for epoch in range(starting_epoch, starting_epoch + epochs):
        seg_criterion = nn.NLLLoss2d(weight=class_weights)
        cls_criterion = nn.BCEWithLogitsLoss(weight=class_weights)
        epoch_losses = []
        # train_iterator = tqdm(train_loader, total=max_steps // batch_size + 1)
        net.train()
        print('------------epoch[{}]----------'.format(epoch + 1))
        for i, (x, y, y_cls) in enumerate(train_loader):
            optimizer.zero_grad()
            x, y, y_cls = Variable(x).cuda(), Variable(y).cuda(), Variable(
                y_cls).float().cuda()
            out, out_cls = net(x)
            pred = out.data.max(1)[1].cpu().numpy()
            seg_loss, cls_loss = seg_criterion(out, y), cls_criterion(
                out_cls, y_cls)
            loss = seg_loss + alpha * cls_loss
            epoch_losses.append(loss.item())
            running_score.update(y.data.cpu().numpy(), pred)
            if (i + 1) % 138 == 0:
                score, class_iou = running_score.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info('{}:{}'.format(k, v))
                running_score.reset()
            print_format_str = "Epoch[{}] batch[{}] loss = {:.4f} LR = {}"
            print_str = print_format_str.format(epoch + 1, i + 1, loss.item(),
                                                scheduler.get_lr()[0])
            print(print_str)
            logger.info(print_str)
            '''
            status = '[{}] loss = {:.4f} avg = {:.4f}, LR = {}'.format(
                epoch + 1, loss.item(), np.mean(epoch_losses), scheduler.get_lr()[0])
            train_iterator.set_description(status)
            '''
            loss.backward()
            optimizer.step()

        scheduler.step()
        if epoch + 1 > 20:
            train_loss = ('%.4f' % np.mean(epoch_losses))
            torch.save(
                net.state_dict(),
                os.path.join(
                    models_path,
                    '_'.join(["PSPNet", str(epoch + 1), train_loss]) + '.pth'))
Example #20
0
        net = utilities.init_model()
        print(net)
        optimizer = utilities.init_optimizer(net.parameters())
        print(optimizer)

        print('Learning rate drop schedule: ', kvs['args'].lr_drop)

        scheduler = MultiStepLR(optimizer, milestones=kvs['args'].lr_drop, gamma=kvs['args'].learning_rate_decay)

        train_losses = []
        val_losses = []

        for epoch in range(kvs['args'].n_epochs):
            kvs.update('cur_epoch', epoch)

            print(colored('====> ', 'red') + 'Learning rate: ', str(scheduler.get_lr())[1:-1])
            train_loss = utilities.train_epoch(net, optimizer, train_loader)
            val_out = utilities.validate_epoch(net, val_loader)
            val_loss, preds, gt, val_acc = val_out

            print('Epoch: ', epoch)
            print('Acc: ', val_acc)

            metrics.log_metrics(writers[fold_id], train_loss, val_loss, gt, preds)

            session.save_checkpoint(net, 'val_loss', 'lt')  # lt, less than
            
            scheduler.step()

            gc.collect()
Example #21
0
    for train_batch, (images, target) in enumerate(train_iterator):

        images = images.cuda()

        pred = net(images)

        loss_xx, loss_info = loss_detect(pred, target)

        assert not np.isnan(loss_xx.data.cpu().numpy())

        epoch_loss += loss_xx
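        # Note: accumulating the raw loss tensor keeps its autograd graph alive for the
        # whole epoch; `epoch_loss += loss_xx.detach()` (or a float accumulator with
        # `loss_xx.item()`) would avoid that memory growth.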

        status = '[{0}] lr = {1} batch_loss = {2:.3f} epoch_loss = {3:.3f} '.format(
            epoch + 1,
            scheduler.get_lr()[0], loss_xx.data,
            epoch_loss.data / (train_batch + 1))

        train_iterator.set_description(status)

        for tag, value in loss_info.items():
            logger.add_scalar(tag, value, step)

        loss_xx.backward()

        optimizer.step()
        optimizer.zero_grad()
        step += 1

    if epoch % 1 == 0 and epoch > 30:
        print("Evaluate~~~~~   ")
Example #22
0
def train():
    config = Configures(cfg='train.yml')

    seed = int(config('train', 'RANDOM_SEED'))
    base_lr = float(config('train', 'LR'))
    max_steps = int(config('data', 'SIZE'))
    alpha = float(config('train', 'ALPHA'))
    task = config('train', 'TASK') # 'seg'/'offset'
    batch_size = int(config('train', 'BATCH_SIZE'))
    # Epochs per training stage: classification (level1), segmentation (level2), joint (level3)
    level1 = int(config('train', 'level1'))
    level2 = int(config('train', 'level2'))
    level3 = int(config('train', 'level3'))
    epochnum = level1 + level2 + level3
    milestone = int(config('train', 'MILESTONE'))
    gamma = float(config('train', 'GAMMA'))

    os.environ["CUDA_VISIBLE_DEVICES"] = config('train', 'WORKERS')
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    models_path = os.path.abspath(os.path.expanduser('models'))
    os.makedirs(models_path, exist_ok=True)
    net, starting_epoch = init_net01(config=config)

    voc_loader = VOCloader.Loader(configure=config)
    train_loader = voc_loader()

    optimizer = optim.Adam(net.parameters(), lr=base_lr)
    seg_optimizer = optim.Adam(net.module.segmenter.parameters(), lr=base_lr)
    scheduler = MultiStepLR(optimizer,
                            milestones=[x * milestone for x in range(1, 1000)],
                            gamma=gamma)
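    # Decaying at every multiple of `milestone` makes this a fixed-interval schedule;
    # a sketch of the equivalent using StepLR from torch.optim.lr_scheduler
    # (same optimizer and gamma assumed):
    #   scheduler = StepLR(optimizer, step_size=milestone, gamma=gamma)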
    cls_criterion = nn.BCEWithLogitsLoss()
    ''' Losses tested for offsetmap
    mse_loss = nn.MSELoss()
    l1_loss = nn.L1Loss()
    d2_loss = (lambda a, b:
                torch.sum(
                    torch.sqrt(
                        torch.pow(a[:, 0, :, :] - b[:, 0, :, :], 2)
                        + torch.pow(a[:, 1, :, :] - b[:, 1, :, :], 2))))
    '''
    smthL1_criterion = nn.SmoothL1Loss()
    seg_criterion = nn.NLLLoss2d()
    nll_criterion = nn.NLLLoss2d()

    curepoch = 0
    global_loss = []
    
    for epoch in range(starting_epoch, starting_epoch + epochnum):
        curepoch += 1

        epoch_losses = []
        epoch_ins_losses = []
        epoch_cls_losses = []

        train_iterator = tqdm(train_loader, total=max_steps // batch_size + 1)
        steps = 0

        net.train()
        for x, y, y_cls, y_seg in train_iterator:
            steps += batch_size
            x = Variable(x.cuda())
            y = Variable(y.cuda())
            y_cls = Variable(y_cls.cuda())
            y_seg = Variable(y_seg.cuda())

            # optimizer.zero_grad()

            if curepoch <= level1:
                optimizer.zero_grad()
                out_cls = net(x, function='classification')
                loss = torch.abs(cls_criterion(out_cls, y_cls))
                epoch_losses.append(loss.data[0])

                if curepoch == level1:
                    net.module.transfer_weight()

                status = '[{0}] Classification; loss:{1:0.6f}/{2:0.6f}, LR:{3:0.8f}'.format(
                    epoch + 1,
                    loss.data[0],
                    np.mean(epoch_losses),
                    scheduler.get_lr()[0])
                loss.backward()
                optimizer.step()
            
            elif curepoch <= level1 + level2:
                seg_optimizer.zero_grad()
                out_segment = net(x, function='segmentation')
                loss = torch.abs(seg_criterion(out_segment, y_seg))
                epoch_losses.append(loss.data[0])

                status = '[{0}] Segmentation; loss:{1:0.6f}/{2:0.6f}, LR:{3:0.8f}'.format(
                    epoch + 1,
                    loss.data[0],
                    np.mean(epoch_losses),
                    scheduler.get_lr()[0])
                loss.backward()
                seg_optimizer.step()            

            elif curepoch <= level1 + level2 + level3:
                optimizer.zero_grad()
                out_cls, out_segment = net(x, function='classification + segmentation')
                loss = alpha * torch.abs(cls_criterion(out_cls, y_cls)) + torch.abs(seg_criterion(out_segment, y_seg))
                epoch_losses.append(loss.data[0])

                status = '[{0}] Double; loss:{1:0.6f}/{2:0.6f}, LR:{3:0.8f}'.format(
                    epoch + 1,
                    loss.data[0],
                    np.mean(epoch_losses),
                    scheduler.get_lr()[0])
                loss.backward()
                optimizer.step()

            train_iterator.set_description(status)
            # loss.backward()
            # optimizer.step()
        if curepoch <= level1 or curepoch > level1 + level2:
            scheduler.step()
        torch.save(net.state_dict(), os.path.join(models_path, '_'.join(["dense", task, str(epoch + 1)])))
        
        global_loss.append((curepoch, loss.data[0]))

    with open('train.log', 'w') as f:
        f.write(str(global_loss))
Example #23
0
# In[ ]:


criterion = nn.CrossEntropyLoss()


# In[ ]:


best_prec1 = 0
best_epoch = 0
for epoch in range(last_epoch+1, opt.TRAIN.EPOCHS+1):
    logger.info('-'*50)
    lr_scheduler.step(epoch)
    logger.info('lr: {}'.format(lr_scheduler.get_lr()))
    train(train_loader, model, criterion, optimizer, epoch)
    prec1 = validate(test_loader, model, criterion)
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    if is_best:
        best_epoch = epoch
        
    if epoch % opt.TRAIN.SAVE_FREQ == 0:
        save_checkpoint({
            'epoch': epoch,
            'arch': opt.MODEL.ARCH,
            'state_dict': model.module.state_dict(),
            'best_prec1': best_prec1,
            'prec1': prec1,
            'optimizer' : optimizer.state_dict(),
Example #24
0
def main():
    global min_loss
    # phase argument
    mode = sys.argv[1]
    if len(sys.argv) == 4:
        save_weight_path = sys.argv[2]
        train_epoch = int(sys.argv[3])
    elif len(sys.argv) == 5:
        load_weight_path = sys.argv[2]
        test_data_path = sys.argv[3]
        output_path = sys.argv[4]
    else:
        print("please input correct command\n")

    
    # Hyperparameters
    input_size = 22 # 10 players and a ball location
    hidden_size = 120 # tried 50-150; 120 units with 6 layers worked best
    num_layers = 6 # tried 5-12
    fc_size = 50 # hidden size of the two fully connected layers
    batch_size = 100 # 1700 training examples, so tried 25, 50, 100, 170
    learning_rate = 0.002 # initial learning rate

    lock_epoch = 1000
    use_history_data = True


    if mode == "train":
        if train_epoch != 1:
            train_epoch = lock_epoch
        if use_history_data:
            train_set = pickle.load(open("train_data.p", "rb"))
            test_set = pickle.load(open("test_data.p", "rb"))
        else:
            train_set, test_set = get_set(batch_size, "my_train.p")

        net = m.lSTM(input_size, hidden_size, num_layers, batch = batch_size, FC_size = fc_size)
        # load pre-trained model for fine-tuning
        if os.path.isfile("weight_00117"):
            net.load_state_dict(torch.load("weight_00117"))
            train_epoch = 100
            learning_rate = 0.00001
        # load data
        start_time = time.time()
        test_loader = get_testloader(test_set, batch_size = batch_size)
        end_time = time.time() - start_time
        print("--------------------------\nload data time: {}min {}s".format(end_time//60, int(end_time%60)))
        # train settings
        criterion = torch.nn.MSELoss() # MSELoss is good for regression
        '''
        I tried SGD, Adagrad, and Adam; Adam with lr=0.002 was the best choice.
        I tried lr = 0.5, 0.1, 0.05, 0.01, 0.004, 0.002, 0.001 and selected the best one.
        Omitting weight_decay gave a slightly lower dev-set loss (a little strange).
        Learning-rate scheduling also matters; these MultiStep milestones were the best of those I tried.
        '''
        optimizer = torch.optim.Adam(net.parameters(), lr = learning_rate) # weight_decay=0.0001 
        lr_scheduler = MultiStepLR(optimizer, milestones = [100, 300, 800], gamma = 0.2)
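        # Starting from the default lr=0.002, gamma=0.2 at these milestones gives
        # 4e-4 after epoch 100, 8e-5 after epoch 300, and 1.6e-5 after epoch 800.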
        # train
        for epoch in range(train_epoch):
            lr_scheduler.step()
            lr = lr_scheduler.get_lr()[0]
            train_loader = get_trainloader(train_set, batch_size=batch_size)
            train(epoch, lr, net, criterion, optimizer, train_loader, batch_size=batch_size)
            cur_loss = test(epoch, net, criterion, save_weight_path, test_loader)
            if cur_loss < min_loss: # select the best model by dev-set loss
                print('save the best model')
                torch.save(net.cpu().state_dict(), save_weight_path)
                min_loss = cur_loss
        print("global min loss:{}".format(min_loss))
    elif mode == "test":
        net = m.lSTM(input_size, hidden_size, num_layers, batch=1, FC_size=fc_size)
        evaluate(net, load_weight_path, test_data_path, output_path)
Example #25
0
def fit(model_z,
        train,
        test,
        val=None,
        training_params=None,
        predict_params=None,
        validation_params=None,
        export_params=None,
        optim_params=None,
        model_selection_params=None):
    """
    This function is the core of an experiment. It performs the ml procedure as well as the call to validation.
    :param training_params: parameters for the training procedure
    :param val: validation set
    :param test: the test set
    :param train: The training set
    :param optim_params:
    :param export_params:
    :param validation_params:
    :param predict_params:
    :param model_z: the model that should be trained
    :param model_selection_params:
    """
    # configuration

    training_params, predict_params, validation_params, export_params, optim_params, \
        cv_params = merge_dict_set(
            training_params, TRAINING_PARAMS,
            predict_params, PREDICT_PARAMS,
            validation_params, VALIDATION_PARAMS,
            export_params, EXPORT_PARAMS,
            optim_params, OPTIM_PARAMS,
            model_selection_params, MODEL_SELECTION_PARAMS
        )

    train_loader, test_loader, val_loader = _dataset_setup(
        train, test, val, **training_params)

    statistics_path = output_path('metric_statistics.dump')

    metrics_stats = Statistics(
        model_z, statistics_path,
        **cv_params) if cv_params.pop('cross_validation') else None

    validation_path = output_path('validation.txt')

    # training parameters
    optim = optim_params.pop('optimizer')
    iterations = training_params.pop('iterations')
    gamma = training_params.pop('gamma')
    loss = training_params.pop('loss')
    log_modulo = training_params.pop('log_modulo')
    val_modulo = training_params.pop('val_modulo')
    first_epoch = training_params.pop('first_epoch')

    # callbacks for ml tests
    vcallback = validation_params.pop(
        'vcallback') if 'vcallback' in validation_params else None

    if iterations is None:
        print_errors(
            'Iterations must be set',
            exception=TrainingConfigurationException('Iterations is None'))

    # before ml callback
    if vcallback is not None and special_parameters.train and first_epoch < max(
            iterations):
        init_callbacks(vcallback, val_modulo,
                       max(iterations) // val_modulo, train_loader.dataset,
                       model_z)

    max_iterations = max(iterations)

    if special_parameters.train and first_epoch < max(iterations):
        print_h1('Training: ' + special_parameters.setup_name)

        loss_logs = [] if first_epoch < 1 else load_loss('loss_train')

        loss_val_logs = [] if first_epoch < 1 else load_loss('loss_validation')

        opt = create_optimizer(model_z.parameters(), optim, optim_params)

        scheduler = MultiStepLR(opt, milestones=list(iterations), gamma=gamma)
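        # `iterations` doubles as the decay schedule and the epoch budget: training runs
        # for max(iterations) epochs and the LR is multiplied by gamma at each listed epoch.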

        # number of batches per epoch
        epoch_size = len(train_loader)

        # one log per epoch if value is -1
        log_modulo = epoch_size if log_modulo == -1 else log_modulo

        epoch = 0
        for epoch in range(max_iterations):

            if epoch < first_epoch:
                # opt.step()
                _skip_step(scheduler, epoch)
                continue
            # saving epoch to enable restart
            export_epoch(epoch)
            model_z.train()

            # printing new epoch
            print_h2('-' * 5 + ' Epoch ' + str(epoch + 1) + '/' +
                     str(max_iterations) + ' (lr: ' + str(scheduler.get_lr()) +
                     ') ' + '-' * 5)

            running_loss = 0.0

            for idx, data in enumerate(train_loader):

                # get the inputs
                inputs, labels = data

                # wrap labels in Variable as input is managed through a decorator
                # labels = model_z.p_label(labels)
                if use_gpu():
                    labels = labels.cuda()

                # zero the parameter gradients
                opt.zero_grad()
                outputs = model_z(inputs)
                loss_value = loss(outputs, labels)
                loss_value.backward()

                opt.step()

                # print running statistics
                running_loss += loss_value.item()
                if idx % log_modulo == log_modulo - 1:  # print every log_modulo mini-batches
                    print('[%d, %5d] loss: %.5f' %
                          (epoch + 1, idx + 1, running_loss / log_modulo))

                    # tensorboard support
                    add_scalar('Loss/train', running_loss / log_modulo)
                    loss_logs.append(running_loss / log_modulo)
                    running_loss = 0.0

            # end of epoch update of learning rate scheduler
            scheduler.step(epoch + 1)

            # saving the model and the current loss after each epoch
            save_checkpoint(model_z, optimizer=opt)

            # validation of the model
            if epoch % val_modulo == val_modulo - 1:
                validation_id = str(int((epoch + 1) / val_modulo))

                # validation call
                predictions, labels, loss_val = predict(
                    model_z, val_loader, loss, **predict_params)
                loss_val_logs.append(loss_val)

                res = '\n[validation_id:' + validation_id + ']\n' + validate(
                    predictions,
                    labels,
                    validation_id=validation_id,
                    statistics=metrics_stats,
                    **validation_params)

                # save statistics for robust cross validation
                if metrics_stats:
                    metrics_stats.save()

                print_notification(res)

                if special_parameters.mail == 2:
                    send_email(
                        'Results for XP ' + special_parameters.setup_name +
                        ' (epoch: ' + str(epoch + 1) + ')', res)
                if special_parameters.file:
                    save_file(
                        validation_path,
                        'Results for XP ' + special_parameters.setup_name +
                        ' (epoch: ' + str(epoch + 1) + ')', res)

                # checkpoint
                save_checkpoint(model_z,
                                optimizer=opt,
                                validation_id=validation_id)

                # callback
                if vcallback is not None:
                    run_callbacks(vcallback, (epoch + 1) // val_modulo)

            # save loss
            save_loss(
                {  # // log_modulo * log_modulo in case log_modulo does not divide epoch_size
                    'train': (loss_logs, log_modulo),
                    'validation':
                    (loss_val_logs,
                     epoch_size // log_modulo * log_modulo * val_modulo)
                },
                ylabel=str(loss))

        # saving last epoch
        export_epoch(epoch + 1)  # if --restart is set, the train will not be executed

        # callback
        if vcallback is not None:
            finish_callbacks(vcallback)

    # final validation
    if special_parameters.evaluate or special_parameters.export:
        print_h1('Validation/Export: ' + special_parameters.setup_name)
        if metrics_stats is not None:
            # change the parameter states of the model to best model
            metrics_stats.switch_to_best_model()

        predictions, labels, val_loss = predict(model_z,
                                                test_loader,
                                                loss,
                                                validation_size=-1,
                                                **predict_params)

        if special_parameters.evaluate:

            res = validate(predictions,
                           labels,
                           statistics=metrics_stats,
                           **validation_params,
                           final=True)

            print_notification(res, end='')

            if special_parameters.mail >= 1:
                send_email(
                    'Final results for XP ' + special_parameters.setup_name,
                    res)
            if special_parameters.file:
                save_file(
                    validation_path,
                    'Final results for XP ' + special_parameters.setup_name,
                    res)

        if special_parameters.export:
            export_results(test_loader.dataset, predictions, **export_params)

    return metrics_stats
Example #26
0
            iter_time += iter_timer.toc(average=False)
            load_timer.tic()

        duration = outer_timer.toc(average=False)
        logging.info("epoch {}: {} seconds; Path: {}".format(
            epoch, duration, opt.expr_dir))
        # logging.info("load/iter/cuda: {} vs {} vs {} seconds; iter: {}".format(load_time, iter_time, net.cudaTimer.tot_time, net.cudaTimer.calls))
        # net.cudaTimer.tot_time = 0
        logging.info("load/iter: {} vs {} seconds;".format(
            load_time, iter_time))

        if epoch >= 5 and epoch % save_model_interval == 0:
            save_name = os.path.join(opt.expr_dir, '%06d.h5' % epoch)
            network.save_net(save_name, net)

        if scheduler is not None:
            scheduler.step()
            logging.info("lr for next epoch: {}".format(scheduler.get_lr()))

        logging.info("Train loss: {}".format(
            train_loss / data_loader_train.get_num_samples()))

        if opt.use_tensorboard:
            try:
                vis_exp.add_scalar_value('train_loss',
                                         train_loss /
                                         data_loader_train.get_num_samples(),
                                         step=epoch)
            except Exception:
                # tensorboard logging is best-effort; ignore failures
                pass
Example #27
0
def main():
    args = get_args()
    setup_logger('{}/log-train'.format(args.dir), args.log_level)
    logging.info(' '.join(sys.argv))

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    # WARNING(fangjun): we have to select GPU at the very
    # beginning; otherwise you will get trouble later
    kaldi.SelectGpuDevice(device_id=args.device_id)
    kaldi.CuDeviceAllowMultithreading()
    device = torch.device('cuda', args.device_id)

    den_fst = fst.StdVectorFst.Read(args.den_fst_filename)

    # TODO(fangjun): pass these options from commandline
    opts = chain.ChainTrainingOptions()
    opts.l2_regularize = 5e-4
    opts.leaky_hmm_coefficient = 0.1

    den_graph = chain.DenominatorGraph(fst=den_fst, num_pdfs=args.output_dim)

    model = get_chain_model(feat_dim=args.feat_dim,
                            output_dim=args.output_dim,
                            lda_mat_filename=args.lda_mat_filename,
                            hidden_dim=args.hidden_dim,
                            kernel_size_list=args.kernel_size_list,
                            stride_list=args.stride_list)

    start_epoch = 0
    num_epochs = args.num_epochs
    learning_rate = args.learning_rate
    best_objf = -100000

    if args.checkpoint:
        start_epoch, learning_rate, best_objf = load_checkpoint(
            args.checkpoint, model)
        logging.info(
            'loaded from checkpoint: start epoch {start_epoch}, '
            'learning rate {learning_rate}, best objf {best_objf}'.format(
                start_epoch=start_epoch,
                learning_rate=learning_rate,
                best_objf=best_objf))

    model.to(device)

    dataloader = get_egs_dataloader(egs_dir=args.cegs_dir,
                                    egs_left_context=args.egs_left_context,
                                    egs_right_context=args.egs_right_context)

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=args.l2_regularize)

    scheduler = MultiStepLR(optimizer, milestones=[1, 2, 3, 4, 5], gamma=0.5)
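    # With gamma=0.5 at epochs 1-5 the learning rate is halved after each of the first
    # five epochs, i.e. it settles at learning_rate / 32 from then on.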
    criterion = KaldiChainObjfFunction.apply

    tf_writer = SummaryWriter(log_dir='{}/tensorboard'.format(args.dir))

    best_epoch = start_epoch
    best_model_path = os.path.join(args.dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(args.dir, 'best-epoch-info')
    try:
        for epoch in range(start_epoch, args.num_epochs):
            learning_rate = scheduler.get_lr()[0]
            logging.info('epoch {}, learning rate {}'.format(
                epoch, learning_rate))
            tf_writer.add_scalar('learning_rate', learning_rate, epoch)

            objf = train_one_epoch(dataloader=dataloader,
                                   model=model,
                                   device=device,
                                   optimizer=optimizer,
                                   criterion=criterion,
                                   current_epoch=epoch,
                                   opts=opts,
                                   den_graph=den_graph,
                                   tf_writer=tf_writer)
            scheduler.step()

            if best_objf is None:
                best_objf = objf
                best_epoch = epoch

            # the higher, the better
            if objf > best_objf:
                best_objf = objf
                best_epoch = epoch
                save_checkpoint(filename=best_model_path,
                                model=model,
                                epoch=epoch,
                                learning_rate=learning_rate,
                                objf=objf)
                save_training_info(filename=best_epoch_info_filename,
                                   model_path=best_model_path,
                                   current_epoch=epoch,
                                   learning_rate=learning_rate,
                                   objf=best_objf,
                                   best_objf=best_objf,
                                   best_epoch=best_epoch)

            # we always save the model for every epoch
            model_path = os.path.join(args.dir, 'epoch-{}.pt'.format(epoch))
            save_checkpoint(filename=model_path,
                            model=model,
                            epoch=epoch,
                            learning_rate=learning_rate,
                            objf=objf)

            epoch_info_filename = os.path.join(args.dir,
                                               'epoch-{}-info'.format(epoch))
            save_training_info(filename=epoch_info_filename,
                               model_path=model_path,
                               current_epoch=epoch,
                               learning_rate=learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               best_epoch=best_epoch)

    except KeyboardInterrupt:
        # save the model when ctrl-c is pressed
        model_path = os.path.join(args.dir,
                                  'epoch-{}-interrupted.pt'.format(epoch))
        # use a very small objf for interrupted model
        objf = -100000
        save_checkpoint(model_path,
                        model=model,
                        epoch=epoch,
                        learning_rate=learning_rate,
                        objf=objf)

        epoch_info_filename = os.path.join(
            args.dir, 'epoch-{}-interrupted-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           best_epoch=best_epoch)

    tf_writer.close()
    logging.warning('Done')
Example #28
0
            img, gt, gt_cls = img, gt.long(), gt_cls.float()

            # Forward pass
            out, out_cls = net(img)
            seg_loss, cls_loss = seg_criterion(out, gt), cls_criterion(
                out_cls, gt_cls)
            loss = seg_loss + args.alpha * cls_loss

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Log
            epoch_losses.append(loss.item())
            status = '[{0}] step = {1}/{2}, loss = {3:0.4f} avg = {4:0.4f}, LR = {5:0.7f}'.format(
                epoch, count + 1, len(train_loader), loss.item(),
                np.mean(epoch_losses),
                scheduler.get_lr()[0])
            print(status)

        scheduler.step()
        if epoch % 10 == 0:
            torch.save(
                net.state_dict(),
                os.path.join(models_path, '_'.join(["PSPNet",
                                                    str(epoch)])))

    torch.save(net.state_dict(),
               os.path.join(models_path, '_'.join(["PSPNet", 'last'])))
Example #29
0
def train_model_multistage_lowlight():

    device = DEVICE
    # Prepare the training data
    train_set = HsiCubicTrainDataset('./data/train_lowlight_patchsize32/')
    #print('trainset32 training example:', len(train_set32))

    #train_set_64 = HsiCubicTrainDataset('./data/train_lowlight_patchsize64/')

    #train_set_list = [train_set32, train_set_64]
    #train_set = ConcatDataset(train_set_list) # sample sizes must match across sets, otherwise concatenation fails
    print('total training example:', len(train_set))

    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # Load the test label data
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # Load the test data
    batch_size = 1
    test_data_dir = './data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=batch_size,
                                 shuffle=False)

    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape

    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # Create the model
    net = MultiStageHSID(K)
    init_params(net)
    #net = nn.DataParallel(net).to(device)
    net = net.to(device)

    # Create the optimizer
    #hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE, betas=(0.9, 0.999))
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[40, 60, 80], gamma=0.1)

    # Define the loss function
    #criterion = nn.MSELoss()

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0

    first_batch = next(iter(train_loader))

    best_psnr = 0
    best_epoch = 0
    best_iter = 0
    start_epoch = 1
    num_epoch = 100

    for epoch in range(start_epoch, num_epoch + 1):
        epoch_start_time = time.time()
        scheduler.step()
        print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        print(scheduler.get_lr())
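        # In recent PyTorch versions get_last_lr() is the supported way to read the
        # current learning rate; get_lr() is intended for use inside step() and may
        # emit a warning when called directly.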
        gen_epoch_loss = 0

        net.train()
        #for batch_idx, (noisy, label) in enumerate([first_batch] * 300):
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
            #print('batch_idx=', batch_idx)
            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            hsid_optimizer.zero_grad()
            #denoised_img = net(noisy, cubic)
            #loss = loss_fuction(denoised_img, label)

            residual = net(noisy, cubic)
            #loss = loss_fuction(residual, label-noisy)
            # Sum the per-stage losses (builtin sum keeps the autograd graph intact)
            loss = sum(
                loss_fuction(residual[j], label) for j in range(len(residual)))
            loss.backward()  # calcu gradient
            hsid_optimizer.step()  # update parameter

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                    )
                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            # step++: each loop iteration (one batch) counts as one step
            cur_step += 1

        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        #scheduler.step()
        #print("Decaying learning rate to %g" % scheduler.get_last_lr()[0])

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            }, f"checkpoints/hsid_multistage_patchsize64_{epoch}.pth")

        # Evaluation on the test set
        net.eval()
        for batch_idx, (noisy_test, cubic_test,
                        label_test) in enumerate(test_dataloader):
            noisy_test = noisy_test.type(torch.FloatTensor)
            label_test = label_test.type(torch.FloatTensor)
            cubic_test = cubic_test.type(torch.FloatTensor)

            noisy_test = noisy_test.to(DEVICE)
            label_test = label_test.to(DEVICE)
            cubic_test = cubic_test.to(DEVICE)

            with torch.no_grad():

                residual = net(noisy_test, cubic_test)
                denoised_band = noisy_test + residual[0]

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, batch_idx] = denoised_band_numpy

                if batch_idx == 49:
                    residual_squeezed = torch.squeeze(residual[0], axis=0)
                    denoised_band_squeezed = torch.squeeze(denoised_band,
                                                           axis=0)
                    label_test_squeezed = torch.squeeze(label_test, axis=0)
                    noisy_test_squeezed = torch.squeeze(noisy_test, axis=0)
                    tb_writer.add_image(f"images/{epoch}_restored",
                                        denoised_band_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_residual",
                                        residual_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_label",
                                        label_test_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_noisy",
                                        noisy_test_squeezed,
                                        1,
                                        dataformats='CHW')

        psnr = PSNR(denoised_hsi, test_label_hsi)
        ssim = SSIM(denoised_hsi, test_label_hsi)
        sam = SAM(denoised_hsi, test_label_hsi)

        # Compute PSNR and SSIM
        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(psnr, ssim, sam))
        tb_writer.add_scalars("validation metrics", {
            'average PSNR': psnr,
            'average SSIM': ssim,
            'avarage SAM': sam
        }, epoch)  # makes it easy to see which epoch performed best

        # Save the best model
        if psnr > best_psnr:
            best_psnr = psnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                }, f"checkpoints/hsid_multistage_patchsize64_best.pth")

        print(
            "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
            % (epoch, cur_step, psnr, best_epoch, best_iter, best_psnr))

        print(
            "------------------------------------------------------------------"
        )
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
              format(epoch,
                     time.time() - epoch_start_time, gen_epoch_loss,
                     scheduler.get_lr()[0]))
        print(
            "------------------------------------------------------------------"
        )

        # Save the current (latest) model
        torch.save(
            {
                'epoch': epoch,
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict()
            }, os.path.join('./checkpoints', "model_latest.pth"))
    tb_writer.close()
Example #30
0
def _train(dataset_name: str, backbone_name: str, path_to_data_dir: str, path_to_checkpoints_dir: str, path_to_resuming_checkpoint: Optional[str]):
    dataset = DatasetBase.from_name(dataset_name)(path_to_data_dir, DatasetBase.Mode.TRAIN, Config.IMAGE_MIN_SIDE, Config.IMAGE_MAX_SIDE)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8, pin_memory=True)

    Log.i('Found {:d} samples'.format(len(dataset)))

    backbone = BackboneBase.from_name(backbone_name)(pretrained=True)
    model = Model(backbone, dataset.num_classes(), pooling_mode=Config.POOLING_MODE,
                  anchor_ratios=Config.ANCHOR_RATIOS, anchor_scales=Config.ANCHOR_SCALES,
                  rpn_pre_nms_top_n=Config.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=Config.RPN_POST_NMS_TOP_N).cuda()
    optimizer = optim.SGD(model.parameters(), lr=Config.LEARNING_RATE,
                          momentum=Config.MOMENTUM, weight_decay=Config.WEIGHT_DECAY)
    scheduler = MultiStepLR(optimizer, milestones=Config.STEP_LR_SIZES, gamma=Config.STEP_LR_GAMMA)

    step = 0
    time_checkpoint = time.time()
    losses = deque(maxlen=100)
    summary_writer = SummaryWriter(os.path.join(path_to_checkpoints_dir, 'summaries'))
    should_stop = False

    num_steps_to_display = Config.NUM_STEPS_TO_DISPLAY
    num_steps_to_snapshot = Config.NUM_STEPS_TO_SNAPSHOT
    num_steps_to_finish = Config.NUM_STEPS_TO_FINISH

    if path_to_resuming_checkpoint is not None:
        step = model.load(path_to_resuming_checkpoint, optimizer, scheduler)
        Log.i(f'Model has been restored from file: {path_to_resuming_checkpoint}')

    Log.i('Start training')

    while not should_stop:
        for batch_index, (_, image_batch, _, bboxes_batch, labels_batch) in enumerate(dataloader):
            assert image_batch.shape[0] == 1, 'only batch size of 1 is supported'

            image = image_batch[0].cuda()
            bboxes = bboxes_batch[0].cuda()
            labels = labels_batch[0].cuda()

            forward_input = Model.ForwardInput.Train(image, gt_classes=labels, gt_bboxes=bboxes)
            forward_output: Model.ForwardOutput.Train = model.train().forward(forward_input)

            anchor_objectness_loss, anchor_transformer_loss, proposal_class_loss, proposal_transformer_loss = forward_output
            loss = anchor_objectness_loss + anchor_transformer_loss + proposal_class_loss + proposal_transformer_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
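            # scheduler.step() is called once per optimizer step here, so the milestones
            # in Config.STEP_LR_SIZES are measured in training steps rather than epochs.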
            losses.append(loss.item())
            summary_writer.add_scalar('train/anchor_objectness_loss', anchor_objectness_loss.item(), step)
            summary_writer.add_scalar('train/anchor_transformer_loss', anchor_transformer_loss.item(), step)
            summary_writer.add_scalar('train/proposal_class_loss', proposal_class_loss.item(), step)
            summary_writer.add_scalar('train/proposal_transformer_loss', proposal_transformer_loss.item(), step)
            summary_writer.add_scalar('train/loss', loss.item(), step)
            step += 1

            if step == num_steps_to_finish:
                should_stop = True

            if step % num_steps_to_display == 0:
                elapsed_time = time.time() - time_checkpoint
                time_checkpoint = time.time()
                steps_per_sec = num_steps_to_display / elapsed_time
                samples_per_sec = dataloader.batch_size * steps_per_sec
                eta = (num_steps_to_finish - step) / steps_per_sec / 3600
                avg_loss = sum(losses) / len(losses)
                lr = scheduler.get_lr()[0]
                Log.i(f'[Step {step}] Avg. Loss = {avg_loss:.6f}, Learning Rate = {lr:.6f} ({samples_per_sec:.2f} samples/sec; ETA {eta:.1f} hrs)')

            if step % num_steps_to_snapshot == 0 or should_stop:
                path_to_checkpoint = model.save(path_to_checkpoints_dir, step, optimizer, scheduler)
                Log.i(f'Model has been saved to {path_to_checkpoint}')

            if should_stop:
                break

    Log.i('Done')