Example #1
    def select_action(self, state, steps_done):
        util.adjust_learning_rate(self.optimizer, self.lr, steps_done, 10000)
        # global steps_done
        sample = random.random()
        eps_threshold = self.EPS_END + (self.EPS_START - self.EPS_END) * \
                                  np.exp(-1. * steps_done / self.EPS_DECAY)

        if sample > eps_threshold:
            actions = self.get_actions(state)
            action = actions.data.max(1)[1].view(1, 1)
            return action
        else:
            return self.LongTensor([[random.randrange(self.action_dim)]])
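The exploration threshold above decays exponentially with the step count. A self-contained sketch of the same schedule, using illustrative EPS_START/EPS_END/EPS_DECAY values rather than the agent's actual settings:

import numpy as np

# Illustrative constants; the real agent stores these on `self`.
EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 2000

def eps_threshold(steps_done):
    # Decays from EPS_START toward EPS_END as steps_done grows.
    return EPS_END + (EPS_START - EPS_END) * np.exp(-1.0 * steps_done / EPS_DECAY)

for step in (0, 1000, 5000, 20000):
    print(step, round(eps_threshold(step), 3))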
Example #2
def main():

    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.dataset == 'CelebA':
        train_dataset = CelebAPrunedAligned_MAFLVal(args.data_folder,
                                                    train=True,
                                                    pair_image=True,
                                                    do_augmentations=True,
                                                    imwidth=args.image_size,
                                                    crop=args.image_crop)
    elif args.dataset == 'InatAve':
        train_dataset = InatAve(args.data_folder,
                                train=True,
                                pair_image=True,
                                do_augmentations=True,
                                imwidth=args.image_size,
                                imagelist=args.imagelist)
    else:
        raise NotImplementedError('dataset not supported {}'.format(
            args.dataset))

    print(len(train_dataset))
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    # create model and optimizer
    n_data = len(train_dataset)

    input_size = args.image_size - 2 * args.image_crop
    pool_size = int(input_size /
                    2**5)  # 96x96 --> 3; 160x160 --> 5; 224x224 --> 7;

    if args.model == 'resnet50':
        model = InsResNet50(pool_size=pool_size)
        model_ema = InsResNet50(pool_size=pool_size)
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2, pool_size=pool_size)
        model_ema = InsResNet50(width=2, pool_size=pool_size)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4, pool_size=pool_size)
        model_ema = InsResNet50(width=4, pool_size=pool_size)
    elif args.model == 'resnet18':
        model = InsResNet18(width=1, pool_size=pool_size)
        model_ema = InsResNet18(width=1, pool_size=pool_size)
    elif args.model == 'resnet34':
        model = InsResNet34(width=1, pool_size=pool_size)
        model_ema = InsResNet34(width=1, pool_size=pool_size)
    elif args.model == 'resnet101':
        model = InsResNet101(width=1, pool_size=pool_size)
        model_ema = InsResNet101(width=1, pool_size=pool_size)
    elif args.model == 'resnet152':
        model = InsResNet152(width=1, pool_size=pool_size)
        model_ema = InsResNet152(width=1, pool_size=pool_size)
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    # copy weights from `model' to `model_ema'
    moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    contrast = MemoryMoCo(128,
                          n_data,
                          args.nce_k,
                          args.nce_t,
                          use_softmax=True).cuda(args.gpu)

    criterion = NCESoftmaxLoss()
    criterion = criterion.cuda(args.gpu)

    model = model.cuda()
    model_ema = model_ema.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            contrast.load_state_dict(checkpoint['contrast'])
            model_ema.load_state_dict(checkpoint['model_ema'])
            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                contrast, criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'],
                         epoch)

        # save model: 'current.pth' every epoch, plus a periodic snapshot
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'model_ema': model_ema.state_dict(),
        }
        save_file = os.path.join(args.model_folder, 'current.pth')
        torch.save(state, save_file)
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()
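moment_update(model, model_ema, 0) above copies the encoder weights into the momentum encoder. A minimal sketch of what such a helper usually looks like in MoCo-style code (an assumption, not this repo's implementation):

import torch

def moment_update(model, model_ema, m):
    # p_ema = m * p_ema + (1 - m) * p; with m = 0 this is a plain copy.
    with torch.no_grad():
        for p, p_ema in zip(model.parameters(), model_ema.parameters()):
            p_ema.data.mul_(m).add_(p.data, alpha=1.0 - m)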
Example #3
def train(encoder, decoder, train_loader, args):
    if cfg.CUDA:
        encoder = encoder.cuda()
        decoder = decoder.cuda()

    exp_dir = args.save_root
    save_model_path = os.path.join(exp_dir, 'models')
    if not os.path.exists(save_model_path):
        os.makedirs(save_model_path)
    epoch = args.start_epoch
    if epoch != 0:
        decoder.load_state_dict(
            torch.load("%s/models/WE_decoder%d.pth" % (exp_dir, epoch)))
        encoder.load_state_dict(
            torch.load("%s/models/WE_encoder%d.pth" % (exp_dir, epoch)))
        print('loaded parameters from epoch %d' % epoch)

    trainables_en = [p for p in encoder.parameters() if p.requires_grad]
    trainables_de = [p for p in decoder.parameters() if p.requires_grad]
    trainables = trainables_en + trainables_de

    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(trainables,
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    elif args.optim == 'adam':
        optimizer = torch.optim.Adam(trainables,
                                     args.lr,
                                     weight_decay=args.weight_decay,
                                     betas=(0.95, 0.999))
    else:
        raise ValueError('Optimizer %s is not supported' % args.optim)

    criterion_mse = nn.MSELoss()
    criterion_bce = nn.BCELoss()

    save_file = os.path.join(exp_dir, 'results_WD.txt')
    while epoch <= args.epoch:
        loss_meter = AverageMeter()
        epoch += 1
        adjust_learning_rate(args.lr, args.lr_decay, optimizer, epoch)

        encoder.train()
        decoder.train()

        for i, (audio, mask, length) in enumerate(train_loader):
            loss = 0
            word_len = mask.sum(1)
            word_len = word_len.float().cuda()
            audio = audio.float().cuda()
            mask = mask.float().cuda()
            optimizer.zero_grad()

            audio_features, word_nums = encoder(audio, mask, length)
            recons_audio = decoder(audio_features, mask, length)
            loss = reconstruction_loss(audio, recons_audio, length, args)
            loss.backward()
            optimizer.step()

            loss_meter.update(loss.item(), args.batch_size)

            if i % 5 == 0:
                print('iteration = %d | loss = %f ' % (i, loss))

        if epoch % 5 == 0:
            torch.save(encoder.state_dict(),
                       "%s/models/WE_encoder%d.pth" % (exp_dir, epoch))
            torch.save(decoder.state_dict(),
                       "%s/models/WE_decoder%d.pth" % (exp_dir, epoch))
            # true_num,total_predict,real_num,P = evaluation(encoder,decoder,val_loader,val_image_loader,args)
            # info = "epoch {} | loss {:.2f} | True {} | Predict {} | Real {} | P {:.2%}\n".format(epoch,loss,true_num,total_predict,real_num,P)
            info = "epoch {} | loss {:.2f} \n".format(epoch, loss_meter.avg)
            print(info)
            with open(save_file, 'a') as f:
                f.write(info)
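This example (and several below) relies on an AverageMeter to track the running loss. A minimal sketch of the usual implementation (assumed, not copied from the repo):

class AverageMeter:
    """Keeps a running average of a scalar such as the per-batch loss."""
    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count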
Example #4
def main():
    args = get_args()
    log_folder = os.path.join('train_log', args.name)
    writer = SummaryWriter(log_folder)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # number of classes for each dataset.
    if args.dataset == 'PascalVOC':
        num_classes = 21
    elif args.dataset == 'COCO':
        num_classes = 81
    else:
        raise Exception("No dataset named {}.".format(args.dataset))

    # Select Model & Method
    model = models.__dict__[args.arch](num_classes=num_classes)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    # Optimizer
    optimizer = torch.optim.SGD(
        [{
            'params': get_parameters(model, bias=False, final=False),
            'lr': args.lr,
            'weight_decay': args.wd
        }, {
            'params': get_parameters(model, bias=True, final=False),
            'lr': args.lr * 2,
            'weight_decay': 0
        }, {
            'params': get_parameters(model, bias=False, final=True),
            'lr': args.lr * 10,
            'weight_decay': args.wd
        }, {
            'params': get_parameters(model, bias=True, final=True),
            'lr': args.lr * 20,
            'weight_decay': 0
        }],
        momentum=args.momentum)

    if args.resume:
        model = load_model(model, args.resume)

    train_loader = data_loader(args)
    data_iter = iter(train_loader)
    train_t = tqdm(range(args.max_iter))
    model.train()
    for global_iter in train_t:
        try:
            images, target, gt_map = next(data_iter)
        except StopIteration:
            data_iter = iter(data_loader(args))
            images, target, gt_map = next(data_iter)

        if args.gpu is not None:
            images = images.cuda(args.gpu)
            gt_map = gt_map.cuda(args.gpu)
            target = target.cuda(args.gpu)

        output = model(images)

        fc8_SEC_softmax = softmax_layer(output)
        loss_s = seed_loss_layer(fc8_SEC_softmax, gt_map)
        loss_e = expand_loss_layer(fc8_SEC_softmax, target, num_classes - 1)
        fc8_SEC_CRF_log = crf_layer(output, images, iternum=10)
        loss_c = constrain_loss_layer(fc8_SEC_softmax, fc8_SEC_CRF_log)

        loss = loss_s + loss_e + loss_c

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # writer add_scalars
        writer.add_scalar('loss', loss, global_iter)
        writer.add_scalars('losses', {
            'loss_s': loss_s,
            'loss_e': loss_e,
            'loss_c': loss_c
        }, global_iter)

        with torch.no_grad():
            if global_iter % 10 == 0:
                # writer add_images (origin, output, gt)
                origin = images.clone().detach() + torch.tensor(
                    [123., 117., 107.]).reshape(1, 3, 1, 1).cuda(args.gpu)

                size = (100, 100)
                origin = F.interpolate(origin, size=size)
                origins = vutils.make_grid(origin,
                                           nrow=15,
                                           padding=2,
                                           normalize=True,
                                           scale_each=True)

                outputs = F.interpolate(output, size=size)
                _, outputs = torch.max(outputs, dim=1)
                outputs = outputs.unsqueeze(1)
                outputs = vutils.make_grid(outputs,
                                           nrow=15,
                                           padding=2,
                                           normalize=True,
                                           scale_each=True).float()

                gt_maps = F.interpolate(gt_map, size=size)
                _, gt_maps = torch.max(gt_maps, dim=1)
                gt_maps = gt_maps.unsqueeze(1)
                gt_maps = vutils.make_grid(gt_maps,
                                           nrow=15,
                                           padding=2,
                                           normalize=True,
                                           scale_each=True).float()

                # gt_maps = F.interpolate(gt_map.unsqueeze(1).float(), size=size)
                # gt_maps = vutils.make_grid(gt_maps, nrow=15, padding=2, normalize=True, scale_each=True).float()

                grid_image = torch.cat((origins, outputs, gt_maps), dim=1)
                writer.add_image(args.name, grid_image, global_iter)


        description = '[{0:4d}/{1:4d}] loss: {2} s: {3} e: {4} c: {5}'.\
            format(global_iter+1, args.max_iter, loss, loss_s, loss_e, loss_c)
        train_t.set_description(desc=description)

        # save snapshot
        if global_iter % args.snapshot == 0:
            save_checkpoint(model.state_dict(), log_folder,
                            'checkpoint_%d.pth.tar' % global_iter)

        # lr decay
        if global_iter % args.lr_decay == 0:
            args.lr = args.lr * 0.1
            optimizer = adjust_learning_rate(optimizer, args.lr)

    print("Training is over...")
    save_checkpoint(model.state_dict(), log_folder, 'last_checkpoint.pth.tar')
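adjust_learning_rate(optimizer, args.lr) above is expected to push the decayed base rate into every parameter group and return the optimizer. A minimal sketch, assuming each group carries a hypothetical 'lr_mult' entry preserving the 1x/2x/10x/20x ratios set up earlier (groups without one fall back to the plain base rate):

def adjust_learning_rate(optimizer, base_lr):
    # Rescale each group; 'lr_mult' is an assumed per-group multiplier.
    for group in optimizer.param_groups:
        group['lr'] = base_lr * group.get('lr_mult', 1)
    return optimizer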
Example #5
def train(config, model, optimizer):
    """Train the model
    Params:
        config:
            json config data
        model:
            model to train
        optimizer:
            optimizer used in training
    Return:
        None
    """
    data_loaders = get_data_loaders()
    ensure_dir(config['log_dir'])
    t_begin = time.time()
    best_acc, old_file = 0, None
    history = {'train':{'loss':[],'acc':[]},'val':{'loss':[],'acc':[]}}
    for epoch in range(config['epoch_num']):
        model.train() # train phase
        epoch_loss = 0
        epoch_correct = 0
        adjust_learning_rate(config,optimizer,epoch)
        for batch_idx, (data,target) in enumerate(data_loaders['train']):
            indx_target = target.clone()
            data, target = Variable(data.cuda()),Variable(target.cuda())
            optimizer.zero_grad()
            output = model(data)
            #define your own loss function here
            loss = F.cross_entropy(output,target)
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
            pred = output.data.max(1)[1]
            correct = pred.cpu().eq(indx_target).sum()
            epoch_correct += correct
            
            if config['batch_log'] and batch_idx % config['batch_log_interval'] == 0 and batch_idx > 0:
                acc = correct * 1.0 / len(data)
                print('Train Epoch: {} [{}/{}] Batch_Loss: {:.6f} Batch_Acc: {:.4f} lr: {:.2e}'.format(
                    epoch, batch_idx * len(data), len(data_loaders['train'].dataset),
                    loss.item(), acc, optimizer.param_groups[0]['lr']))
        elapse_time = time.time() - t_begin
        speed_epoch = elapse_time / (epoch + 1)
        speed_batch = speed_epoch / len(data_loaders['train'])
        eta = speed_epoch * config['epoch_num'] - elapse_time
        print("{}/{} Elapsed {:.2f}s, {:.2f} s/epoch, {:.2f} s/batch, ets {:.2f}s".format(epoch+1,
            config['epoch_num'],elapse_time, speed_epoch, speed_batch, eta))
        
        
        epoch_loss = epoch_loss / len(data_loaders['train']) # average over number of mini-batch
        acc = 100. * epoch_correct / len(data_loaders['train'].dataset)
        print('\tTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
                epoch_loss, epoch_correct, len(data_loaders['train'].dataset), acc))
        history['train']['loss'].append(epoch_loss)  
        history['train']['acc'].append(acc)          
        model_snapshot(model, os.path.join(config['log_dir'], 'latest.pth'))
        
        if epoch % config['val_interval'] == 0:
            model.eval()
            val_loss = 0
            correct = 0
            for data, target in data_loaders['val']:
                indx_target = target.clone()
                data, target = Variable(data.cuda(),volatile=True),  Variable(target.cuda())
                output = model(data)
                val_loss += F.cross_entropy(output, target).item()
                pred = output.data.max(1)[1]  # get the index of the max log-probability
                correct += pred.cpu().eq(indx_target).sum()

            val_loss = val_loss / len(data_loaders['val']) # average over number of mini-batch
            acc = 100. * correct / len(data_loaders['val'].dataset)
            print('\tVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
                val_loss, correct, len(data_loaders['val'].dataset), acc))
            history['val']['loss'].append(val_loss)  
            history['val']['acc'].append(acc)
            if acc > best_acc:
                new_file = os.path.join(config['log_dir'], datetime.now().strftime('%Y-%m-%d-%H-%M-%S')+'-best-{}.pth'.format(epoch))
                model_snapshot(model, new_file, old_file=old_file, verbose=True)
                best_acc = acc
                old_file = new_file
    f = open(config['history'],'wb')
    try:
        pickle.dump(history,f)
    finally:
        f.close()
    print("Total Elapse: {:.2f}s, Best Val Acc: {:.3f}%".format(time.time()-t_begin, best_acc))
Example #6
def train_model(model, criterion, optimizer, dataloaders, model_path, start_epoch, iter_num, logger, device):
    since = time.time()
    best_loss = np.inf
    best_map = 0
    trial_log = args.trial_log
    num_epochs = args.num_epochs
    test_interval = args.test_interval
    burn_in = args.burn_in
    lr = args.learning_rate
    lr_steps = args.lr_steps
    size_grid_cell = args.size_grid_cell
    num_boxes = args.num_boxes
    num_classes = args.num_classes
    conf_thresh = args.conf_thresh
    iou_thresh = args.iou_thresh
    nms_thresh = args.nms_thresh
    port = args.port
    vis = Visualizer(env=trial_log, port=port)
        
    for epoch in range(start_epoch, num_epochs):
        logger.info('Epoch {} / {}'.format(epoch+1, num_epochs))
        logger.info('-' * 64)

        # set learning rate manually
        if epoch in lr_steps:
            lr *= 0.1
        adjust_learning_rate(optimizer, lr)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                # scheduler.step()
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            total_loss = 0.0
            # Iterate over data.
            for i, (inputs, targets) in enumerate(dataloaders[phase]):
                # warm up the learning rate during the burn-in phase
                if phase == 'train':
                    if iter_num < burn_in:
                        burn_lr = get_learning_rate(iter_num, lr, burn_in)
                        adjust_learning_rate(optimizer, burn_lr)
                        iter_num += 1
                    else:
                        adjust_learning_rate(optimizer, lr)
                    
                inputs = inputs.to(device)
                targets = targets.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss, obj_coord_loss, obj_conf_loss, noobj_conf_loss, obj_class_loss = criterion(outputs, targets)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                total_loss += loss.item()

                if phase == 'train':
                    cur_lr = optimizer.state_dict()['param_groups'][0]['lr']
                    vis.plot('cur_lr', cur_lr)
                    logger.info('Epoch [{}/{}], iter [{}/{}], lr: {:g}, loss: {:.4f}, average_loss: {:.4f}'.format(
                        epoch+1, args.num_epochs, i+1, len(dataloaders[phase]), cur_lr, loss.item(), total_loss/(i+1)))
                    logger.debug('  obj_coord_loss: {:.4f}, obj_conf_loss: {:.4f}, noobj_conf_loss: {:.4f}, obj_class_loss: {:.4f}'.format(
                        obj_coord_loss, obj_conf_loss, noobj_conf_loss, obj_class_loss))
                    vis.plot('train_loss', total_loss/(i+1))

            # save model for inferencing and resuming training process
            if phase == 'train':
                torch.save(model.state_dict(), osp.join(model_path, 'latest.pth'))
                torch.save({
                    'iter_num': iter_num,
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, osp.join(model_path, 'latest.tar'))

            # evaluate latest model
            if phase == 'val':
                current_loss = total_loss / (i+1)
                if best_loss > current_loss:
                    best_loss = current_loss
                logger.info('current val loss: {:.4f}, best val Loss: {:.4f}'.format(current_loss, best_loss))
                vis.plot('val_loss', total_loss/(i+1))

                if epoch < 10 or (epoch+1) % test_interval == 0:
                    current_map = calc_map(logger, dataloaders[phase].dataset, model_path, 
                        size_grid_cell, num_boxes, num_classes, conf_thresh, iou_thresh, nms_thresh)
                    # save the best model as so far
                    if best_map < current_map:
                        best_map = current_map
                        torch.save(model.state_dict(), osp.join(model_path, 'best.pth'))
                    logger.info('current val map: {:.4f}, best val map: {:.4f}'.format(current_map, best_map))
                    vis.plot('val_map', current_map)

    time_elapsed = time.time() - since
    logger.info('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    logger.info('Optimization Done.')
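get_learning_rate(iter_num, lr, burn_in) implements the warm-up ("burn-in") used in the loop above. A common YOLO-style choice, assumed here rather than taken from this repo, ramps the rate with a power of the warm-up progress:

def get_learning_rate(iter_num, lr, burn_in, power=4):
    # Rises from ~0 to the target lr over the first `burn_in` iterations.
    return lr * (iter_num / burn_in) ** power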
Example #7
def main():

    global best_error
    best_error = np.Inf

    args = parse_option()
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
    torch.manual_seed(0)

    # train on celebA unlabeled dataset
    train_dataset = CelebAPrunedAligned_MAFLVal(args.data_folder, 
                                                train=True, 
                                                pair_image=False, 
                                                do_augmentations=True,
                                                imwidth=args.image_size,
                                                crop = args.image_crop)
    print('Number of training images: %d' % len(train_dataset))
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True, sampler=None)

    
    # validation set from MAFLAligned trainset for hyperparameter searching
    # we sample 2000 images as our val set
    val_dataset   = MAFLAligned(args.data_folder, 
                                train=True, # train set
                                pair_image=True, 
                                do_augmentations=False,
                                TPS_aug = True, 
                                imwidth=args.image_size, 
                                crop=args.image_crop)
    print('Initial number of validation images: %d' % len(val_dataset)) 
    val_dataset.restrict_annos(num=2000, outpath=args.save_folder, repeat_flag=False)
    print('After restricting the size of validation set: %d' % len(val_dataset))
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=2, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)


    # testing set from MAFLAligned test for evaluating image matching
    test_dataset = MAFLAligned(args.data_folder, 
                               train=False, # test set 
                               pair_image=True, 
                               do_augmentations=False,
                               TPS_aug = True, # match landmark between deformed images
                               imwidth=args.image_size, 
                               crop=args.image_crop)
    print('Number of testing images: %d' % len(test_dataset)) 
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=2, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)

    assert len(val_dataset) == 2000
    assert len(test_dataset) == 1000


    # create model and optimizer
    input_size = args.image_size - 2 * args.image_crop
    pool_size = int(input_size / 2**5) # 96x96 --> 3; 160x160 --> 5; 224x224 --> 7;
    # we use smaller feature map when training the feature distiller for memory issue
    args.train_output_shape = (args.train_out_size, args.train_out_size)
    # we use the original size of the image (e.g. 96x96 face images) during testing
    args.val_output_shape = (args.val_out_size, args.val_out_size)

    if args.model == 'resnet50':
        model = InsResNet50(pool_size=pool_size)
        desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048}
    elif args.model == 'resnet50_half':
        model = InsResNet50(width=0.5, pool_size=pool_size)
        desc_dim = {1:int(64/2), 2:int(256/2), 3:int(512/2), 4:int(1024/2), 5:int(2048/2)}
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2, pool_size=pool_size)
        desc_dim = {1:128, 2:512, 3:1024, 4:2048, 5:4096}
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4, pool_size=pool_size)
        desc_dim = {1:512, 2:1024, 3:2048, 4:4096, 5:8192}
    elif args.model == 'resnet18':
        model = InsResNet18(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:64, 3:128, 4:256, 5:512}
    elif args.model == 'resnet34':
        model = InsResNet34(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:64, 3:128, 4:256, 5:512}
    elif args.model == 'resnet101':
        model = InsResNet101(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048}
    elif args.model == 'resnet152':
        model = InsResNet152(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048}
    elif args.model == 'hourglass':
        model = HourglassNet()
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))
    
    
    # xxx_feat_spectral records the feat dim per layer in hypercol
    # this information is useful to do layer-wise feat normalization in landmark matching
    train_feat_spectral = [] 
    if args.train_use_hypercol:
        for i in range(args.train_layer):
            train_feat_spectral.append(desc_dim[5-i])
    else:
        train_feat_spectral.append(desc_dim[args.train_layer])
    args.train_feat_spectral = train_feat_spectral

    val_feat_spectral = []
    if args.val_use_hypercol:
        for i in range(args.val_layer):
            val_feat_spectral.append(desc_dim[5-i])
    else:
        val_feat_spectral.append(desc_dim[args.val_layer])
    args.val_feat_spectral = val_feat_spectral
    

    # load pretrained moco 
    if args.trained_model_path != 'none':
        print('==> loading pre-trained model')
        ckpt = torch.load(args.trained_model_path, map_location='cpu')
        model.load_state_dict(ckpt['model'], strict=True)
        print("==> loaded checkpoint '{}' (epoch {})".format(
                            args.trained_model_path, ckpt['epoch']))
        print('==> done')
    else:
        print('==> use randomly initialized model')


    # Define feature distiller, set pretrained model to eval mode
    if args.feat_distill:
        model.eval()
        assert np.sum(train_feat_spectral) == np.sum(val_feat_spectral)
        feat_distiller = FeatDistiller(np.sum(val_feat_spectral), 
                                        kernel_size=args.kernel_size,
                                        mode=args.distill_mode,
                                        out_dim = args.out_dim,
                                        softargmax_mul=args.softargmax_mul)
        feat_distiller = nn.DataParallel(feat_distiller)
        feat_distiller.train()
        print('Feature distillation is used: kernel_size:{}, mode:{}, out_dim:{}'.format(
                args.kernel_size, args.distill_mode, args.out_dim))
        feat_distiller = feat_distiller.cuda()
    else:
        feat_distiller = None


    #  evaluate feat distiller on landmark matching, given pretrained moco and feature distiller
    model = model.cuda()
    if args.evaluation_mode:
        if args.feat_distill:
            print("==> use pretrained feature distiller ...")
            feat_ckpt = torch.load(args.trained_feat_model_path, map_location='cpu') 
            # the checkpoint key 'feat_disiller' below is misspelled, but it is kept for compatibility with pretrained checkpoints
            feat_distiller.load_state_dict(feat_ckpt['feat_disiller'], strict=False)
            print("==> loaded checkpoint '{}' (epoch {})".format(
                                args.trained_feat_model_path, feat_ckpt['epoch']))
            same_err, diff_err = validate(test_loader, model, args, 
                                        feat_distiller=feat_distiller, 
                                        visualization=args.visualize_matching)
        else:
            print("==> use hypercolumn ...")
            same_err, diff_err = validate(test_loader, model, args, 
                                        feat_distiller=None, 
                                        visualization=args.visualize_matching)
        exit()


    ## define optimizer for feature distiller  
    if not args.adam:
        if not args.feat_distill:
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.learning_rate,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.SGD(feat_distiller.parameters(),
                                        lr=args.learning_rate,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    else:
        if not args.feat_distill:
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.learning_rate,
                                         betas=(args.beta1, args.beta2),
                                         weight_decay=args.weight_decay,
                                         eps=1e-8,
                                         amsgrad=args.amsgrad)
        else:
            optimizer = torch.optim.Adam(feat_distiller.parameters(),
                                         lr=args.learning_rate,
                                         betas=(args.beta1, args.beta2),
                                         weight_decay=args.weight_decay,
                                         eps=1e-8,
                                         amsgrad=args.amsgrad)



    # set lr scheduler
    if args.cosine: # we use cosine scheduler by default
        eta_min = args.learning_rate * (args.lr_decay_rate ** 3) * 0.1
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min, -1)
    elif args.multistep:
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 250], gamma=0.1)

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)
    cudnn.benchmark = True

    # report the performance of hypercol on landmark matching tasks
    print("==> Testing of initial model on validation set...")
    same_err, diff_err = validate(val_loader, model, args, feat_distiller=None)
    print("==> Testing of initial model on test set...")
    same_err, diff_err = validate(test_loader, model, args, feat_distiller=None)
    
    # training loss for feature projector
    criterion = dense_corr_loss

    # training feature distiller
    for epoch in range(1, args.epochs + 1):
        if args.cosine or args.multistep:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, args, optimizer)

        print("==> training ...")
        time1 = time.time()
        train_loss = train_point_contrast(epoch, train_loader, model, criterion, optimizer, args, 
                                            feat_distiller=feat_distiller)
        time2 = time.time()
        print('train epoch {}, total time {:.2f}, train_loss {:.4f}'.format(epoch, 
                                            time2 - time1, train_loss))
        logger.log_value('train_loss', train_loss, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)


        print("==> validation ...")
        val_same_err, val_diff_err = validate(val_loader, model, args, 
                                            feat_distiller=feat_distiller)


        print("==> testing ...")
        test_same_err, test_diff_err = validate(test_loader, model, args, 
                                            feat_distiller=feat_distiller)

        # save model
        if epoch % args.save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'epoch': epoch,
                'feat_disiller': feat_distiller.state_dict(),
                'val_error': [val_same_err, val_diff_err],
                'test_error': [test_same_err, test_diff_err],
            }
            save_name = 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving regular model!')
            torch.save(state, save_name)

            if val_diff_err < best_error:
                best_error = val_diff_err
                save_name = 'best.pth'
                save_name = os.path.join(args.save_folder, save_name)
                print('saving best model! val_same: {} val_diff: {} test_same: {} test_diff: {}'.format(val_same_err, val_diff_err, test_same_err, test_diff_err))
                torch.save(state, save_name)
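adjust_learning_rate(epoch, args, optimizer) follows the usual step-decay pattern of contrastive-learning codebases. A minimal sketch, assuming args.lr_decay_epochs is a list of milestone epochs and args.lr_decay_rate the per-milestone factor:

import numpy as np

def adjust_learning_rate(epoch, args, optimizer):
    # Apply lr_decay_rate once for every milestone the current epoch has passed.
    steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
    new_lr = args.learning_rate * (args.lr_decay_rate ** steps)
    for group in optimizer.param_groups:
        group['lr'] = new_lr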
Example #8
def main():
    global args, best_EPE
    args = parser.parse_args()
    save_path = '{},{},b{},lr{}'.format(args.arch, args.solver,
                                        args.batch_size, args.lr)
    if not args.no_date:
        timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
        save_path = os.path.join(timestamp, save_path)
    save_path = os.path.join(args.dataset, save_path)
    print('=> will save everything to {}'.format(save_path))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    train_writer = SummaryWriter(os.path.join(save_path, 'train'))
    test_writer = SummaryWriter(os.path.join(save_path, 'test'))
    output_writers = []
    for i in range(3):
        output_writers.append(
            SummaryWriter(os.path.join(save_path, 'test', str(i))))

    # Data loading code
    input_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
        transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1])
    ])
    target_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0], std=[args.div_flow, args.div_flow])
    ])

    if 'KITTI' in args.dataset:
        args.sparse = True
    if args.sparse:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomCrop((320, 448)),
            flow_transforms.RandomVerticalFlip(),
            flow_transforms.RandomHorizontalFlip()
        ])
    else:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomTranslate(10),
            flow_transforms.RandomRotate(10, 5),
            flow_transforms.RandomCrop((320, 448)),
            flow_transforms.RandomVerticalFlip(),
            flow_transforms.RandomHorizontalFlip()
        ])

    print("=> fetching img pairs in '{}'".format(args.data))
    train_set, test_set = datasets.__dict__[args.dataset](
        args.data,
        transform=input_transform,
        target_transform=target_transform,
        co_transform=co_transform,
        split=args.split_file if args.split_file else args.split_value)
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(test_set,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             shuffle=False)

    # create model
    if args.pretrained:
        network_data = torch.load(args.pretrained)
        args.arch = network_data['arch']
        print("=> using pre-trained model '{}'".format(args.pretrained))
    else:
        network_data = None
        print("=> creating model '{}'".format(args.arch))

    model = models.__dict__[args.arch](network_data)
    bias_params = model.bias_parameters()
    weight_params = model.weight_parameters()
    parallel = False

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
            parallel = True
        model = model.to(device)
        cudnn.benchmark = True

        if parallel is True:
            bias_params = model.module.bias_parameters()
            weight_params = model.module.weight_parameters()

    assert (args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    param_groups = [{
        'params': bias_params,
        'weight_decay': args.bias_decay
    }, {
        'params': weight_params,
        'weight_decay': args.weight_decay
    }]
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(param_groups,
                                     args.lr,
                                     betas=(args.momentum, args.beta))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(param_groups,
                                    args.lr,
                                    momentum=args.momentum)

    if args.evaluate:
        best_EPE = validate(val_loader, model, 0, output_writers)
        return

    epoch_size = len(train_loader)
    # device_num = torch.cuda.device_count() if torch.cuda.is_available() else 1
    # epochs = 70000 // device_num // epoch_size + 1 if args.epochs == -1 else args.epochs
    epochs = 70000 // epoch_size + 1 if args.epochs == -1 else args.epochs

    for epoch in range(args.start_epoch, epochs):
        adjust_learning_rate(args, optimizer, epoch, epoch_size)

        # train for one epoch
        train_loss, train_EPE = train(train_loader, model, optimizer, epoch,
                                      epoch_size, train_writer)
        train_writer.add_scalar('mean EPE', train_EPE, epoch)

        # evaluate on validation set
        with torch.no_grad():
            EPE = validate(val_loader, model, epoch, output_writers)
        test_writer.add_scalar('mean EPE', EPE, epoch)

        if best_EPE < 0:
            best_EPE = EPE

        is_best = EPE < best_EPE
        best_EPE = min(EPE, best_EPE)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.module.state_dict() if parallel else model.state_dict(),
                'best_EPE': best_EPE,
                'div_flow': args.div_flow
            }, is_best, save_path)
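save_checkpoint(state, is_best, save_path) is assumed to write the latest checkpoint and copy it to a best-model file when the EPE improved:

import os
import shutil
import torch

def save_checkpoint(state, is_best, save_path, filename='checkpoint.pth.tar'):
    path = os.path.join(save_path, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(save_path, 'model_best.pth.tar'))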
Example #9
def main():
    args = get_args()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # number of classes for each dataset.
    if args.dataset == 'PascalVOC':
        num_classes = 20
    else:
        raise Exception("No dataset named {}.".format(args.dataset))

    # Select Model & Method
    model = models.__dict__[args.arch](pretrained=args.pretrained,
                                       num_classes=num_classes)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    # define loss function (criterion) and optimizer
    criterion = nn.MultiLabelSoftMarginLoss().cuda(args.gpu)
    # criterion = nn.BCEWithLogitsLoss().cuda(args.gpu)

    # Take apart parameters to give different Learning Rate
    param_features = []
    param_classifiers = []

    if args.arch.startswith('vgg'):
        for name, parameter in model.named_parameters():
            if 'features.' in name:
                param_features.append(parameter)
            else:
                param_classifiers.append(parameter)
    elif args.arch.startswith('resnet'):
        for name, parameter in model.named_parameters():
            if 'layer4.' in name or 'fc.' in name:
                param_classifiers.append(parameter)
            else:
                param_features.append(parameter)
    else:
        raise Exception("Fail to recognize the architecture")

    # Optimizer
    optimizer = torch.optim.SGD([
        {'params': param_features, 'lr': args.lr},
        {'params': param_classifiers, 'lr': args.lr * args.lr_ratio}],
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=args.nest)

    # optionally resume from a checkpoint
    if args.resume:
        model, optimizer = load_model(model, optimizer, args)
    train_loader, val_loader, test_loader = data_loader(args)

    saving_dir = os.path.join(args.log_folder, args.name)

    if args.evaluate:
        # test_ap, test_loss = evaluate_cam(val_loader, model, criterion, args)
        # test_ap, test_loss = evaluate_cam2(val_loader, model, criterion, args)
        test_ap, test_loss = evaluate_cam3(val_loader, model, criterion, args)
        print_progress(test_ap, test_loss, 0, 0, prefix='test')
        return

    # Training Phase
    best_m_ap = 0
    for epoch in range(args.start_epoch, args.epochs):

        adjust_learning_rate(optimizer, epoch, args)

        # Train for one epoch
        train_ap, train_loss = \
            train(train_loader, model, criterion, optimizer, epoch, args)
        print_progress(train_ap, train_loss, epoch+1, args.epochs)

        # Evaluate classification
        val_ap, val_loss = validate(val_loader, model, criterion, epoch, args)
        print_progress(val_ap, val_loss, epoch+1, args.epochs, prefix='validation')

        # # Save checkpoint at best performance:
        is_best = val_ap.mean() > best_m_ap
        if is_best:
            best_m_ap = val_ap.mean()

        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_m_ap': best_m_ap,
            'optimizer': optimizer.state_dict(),
        }, is_best, saving_dir)

        save_progress(saving_dir, train_ap, train_loss, val_ap, val_loss, args)
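load_model(model, optimizer, args) above is assumed to restore a checkpoint written by save_checkpoint, matching the keys saved at the end of the loop. A minimal sketch, assuming args.resume holds the checkpoint path:

import os
import torch

def load_model(model, optimizer, args):
    if os.path.isfile(args.resume):
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        args.start_epoch = checkpoint['epoch']
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))
    return model, optimizer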
Example #10
def train(model, train_loader, val_loader, args):
    if cfg.CUDA:
        model = model.cuda()

    exp_dir = args.save_root
    save_model_path = os.path.join(exp_dir, 'models')
    if not os.path.exists(save_model_path):
        os.makedirs(save_model_path)
    epoch = args.start_epoch
    if epoch != 0:
        model.load_state_dict(
            torch.load("%s/models/WBDNet_%d.pth" % (exp_dir, epoch)))
        print('loaded parameters from epoch %d' % epoch)

    trainables = [p for p in model.parameters() if p.requires_grad]

    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(trainables,
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    elif args.optim == 'adam':
        optimizer = torch.optim.Adam(trainables,
                                     args.lr,
                                     weight_decay=args.weight_decay,
                                     betas=(0.95, 0.999))
    else:
        raise ValueError('Optimizer %s is not supported' % args.optim)

    loss_meter = AverageMeter()
    criterion_mse = nn.MSELoss()
    criterion_bce = nn.BCELoss()

    save_file = os.path.join(exp_dir, 'results.txt')
    while epoch <= args.epoch:
        epoch += 1
        adjust_learning_rate(args.lr, args.lr_decay, optimizer, epoch)

        model.train()

        for i, (audio, target, mask, length,
                cap_ID) in enumerate(train_loader):
            loss = 0

            audio = audio.float().cuda()
            target = target.float().cuda()
            mask = mask.cuda()
            optimizer.zero_grad()

            predict = model(audio)
            criterion_bce_log = nn.BCEWithLogitsLoss(
                pos_weight=(args.bce_weight * target + 1.0) * mask)
            loss = criterion_bce_log(predict, target)

            loss.backward()
            optimizer.step()

            loss_meter.update(loss.item(), args.batch_size)

            if i % 100 == 0:
                print('iteration = %d | loss = %f ' % (i, loss))
        if epoch % 1 == 0:
            torch.save(model.state_dict(),
                       "%s/models/WBDNet_%d.pth" % (exp_dir, epoch))
            metrics = evaluation(model, val_loader, args)
            Recall = metrics['recall']
            P = metrics['precision']
            F1 = metrics['F1']
            info = "epoch {} | loss {:.2f} | Recall {:.2%} | Precision {:.2%} | F1 {:.2%}\n".format(
                epoch, loss, Recall, P, F1)
            print(info)
            with open(save_file, 'a') as f:
                f.write(info)
Example #11
def main():

    global best_error
    best_error = np.Inf

    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    train_dataset = getattr(module_data, args.dataset)(args.data_folder, 
                                         train=True, 
                                         pair_image=False, 
                                         do_augmentations=False,
                                         imwidth=args.image_size, 
                                         crop=args.image_crop,
                                         TPS_aug=args.TPS_aug) # using TPS for data augmentation
    val_dataset   = getattr(module_data, args.dataset)(args.data_folder, 
                                         train=False, 
                                         pair_image=False, 
                                         do_augmentations=False,
                                         imwidth=args.image_size, 
                                         crop=args.image_crop)

    print('Number of training images: %d' % len(train_dataset))
    print('Number of validation images: %d' % len(val_dataset))


    # for the few-shot experiments: using limited annotations to train the landmark regression
    if args.restrict_annos > -1:
        if args.resume:
            train_dataset.restrict_annos(num=args.restrict_annos, datapath=args.save_folder, 
                                                    repeat_flag=args.repeat, num_per_epoch = args.num_per_epoch)
        else:
            train_dataset.restrict_annos(num=args.restrict_annos, outpath=args.save_folder,  
                                                    repeat_flag=args.repeat, num_per_epoch = args.num_per_epoch)
        print('Now restricting number of images to %d, sanity check: %d; number per epoch %d' % (args.restrict_annos, 
                                                    len(train_dataset), args.num_per_epoch))
    

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)

    # create model and optimizer
    input_size = args.image_size - 2 * args.image_crop
    pool_size = int(input_size / 2**5) # 96x96 --> 3; 160x160 --> 5; 224x224 --> 7;
    args.output_shape = (48,48)

    if args.model == 'resnet50':
        model = InsResNet50(pool_size=pool_size)
        desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048}
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2, pool_size=pool_size)
        desc_dim = {1:128, 2:512, 3:1024, 4:2048, 5:4096}
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4, pool_size=pool_size)
        desc_dim = {1:512, 2:1024, 3:2048, 4:4096, 5:8192}
    elif args.model == 'resnet18':
        model = InsResNet18(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:64, 3:128, 4:256, 5:512}
    elif args.model == 'resnet34':
        model = InsResNet34(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:64, 3:128, 4:256, 5:512}
    elif args.model == 'resnet101':
        model = InsResNet101(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048}
    elif args.model == 'resnet152':
        model = InsResNet152(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048}
    elif args.model == 'hourglass':
        model = HourglassNet()
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))


    if args.model == 'hourglass':
        feat_dim = 64
    else:
        if args.use_hypercol:
            feat_dim = 0
            for i in range(args.layer):
                feat_dim += desc_dim[5-i]
        else:
            feat_dim = desc_dim[args.layer]

   
    regressor = IntermediateKeypointPredictor(feat_dim, num_annotated_points=args.num_points,
                                               num_intermediate_points=50,
                                               softargmax_mul=100.0)

    print('==> loading pre-trained model')
    ckpt = torch.load(args.trained_model_path, map_location='cpu')
    model.load_state_dict(ckpt['model'], strict=False)
    print("==> loaded checkpoint '{}' (epoch {})".format(args.trained_model_path, ckpt['epoch']))
    print('==> done')

    model = model.cuda()
    regressor = regressor.cuda()

    criterion = regression_loss

    if not args.adam:
        optimizer = torch.optim.SGD(regressor.parameters(),
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(regressor.parameters(),
                                     lr=args.learning_rate,
                                     betas=(args.beta1, args.beta2),
                                     weight_decay=args.weight_decay,
                                     eps=1e-8,
                                     amsgrad=args.amsgrad)
    model.eval()
    cudnn.benchmark = True

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            regressor.load_state_dict(checkpoint['regressor'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_error = checkpoint['best_error']
            best_error = best_error.cuda()
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            if 'opt' in checkpoint.keys():
                # resume optimization hyper-parameters
                print('=> resume hyper parameters')
                if 'bn' in vars(checkpoint['opt']):
                    print('using bn: ', checkpoint['opt'].bn)
                if 'adam' in vars(checkpoint['opt']):
                    print('using adam: ', checkpoint['opt'].adam)
                if 'cosine' in vars(checkpoint['opt']):
                    print('using cosine: ', checkpoint['opt'].cosine)
                args.learning_rate = checkpoint['opt'].learning_rate
                # args.lr_decay_epochs = checkpoint['opt'].lr_decay_epochs
                args.lr_decay_rate = checkpoint['opt'].lr_decay_rate
                args.momentum = checkpoint['opt'].momentum
                args.weight_decay = checkpoint['opt'].weight_decay
                args.beta1 = checkpoint['opt'].beta1
                args.beta2 = checkpoint['opt'].beta2
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # set cosine annealing scheduler
    if args.cosine:

        # last_epoch = args.start_epoch - 2
        # eta_min = args.learning_rate * (args.lr_decay_rate ** 3) * 0.1
        # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min, last_epoch)

        eta_min = args.learning_rate * (args.lr_decay_rate ** 3) * 0.1
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min, -1)
        # dummy loop to catch up with current epoch
        for i in range(1, args.start_epoch):
            scheduler.step()
    elif args.multistep:
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 250], gamma=0.1)
        # dummy loop to catch up with current epoch
        for i in range(1, args.start_epoch):
            scheduler.step()

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        if args.cosine or args.multistep:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        InterOcularError, train_loss = train(epoch, train_loader, model, regressor, criterion, optimizer, args)
        time2 = time.time()
        print('train epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('InterOcularError', InterOcularError, epoch)
        logger.log_value('train_loss', train_loss, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        print("==> testing...")
        test_InterOcularError, test_loss = validate(val_loader, model, regressor, criterion, args)

        logger.log_value('Test_InterOcularError', test_InterOcularError, epoch)
        logger.log_value('test_loss', test_loss, epoch) 

        # save the best model
        if test_InterOcularError < best_error:
            best_error = test_InterOcularError
            state = {
                'opt': args,
                'epoch': epoch,
                'regressor': regressor.state_dict(),
                'best_error': best_error,
                'optimizer': optimizer.state_dict(),
            }
            save_name = '{}.pth'.format(args.model)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving best model!')
            torch.save(state, save_name)

        # save model
        if epoch % args.save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'epoch': epoch,
                'regressor': regressor.state_dict(),
                'best_error': test_InterOcularError,
                'optimizer': optimizer.state_dict(),
            }
            save_name = 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving regular model!')
            torch.save(state, save_name)

        # tensorboard logger
        pass