예제 #1
0
from utils.opts import Opt
from utils.visualizer import Visualizer

from train import train
from val import val

if __name__ == "__main__":

    opt = Opt().parse()
    ########################################
    #                 Model                #
    ########################################
    torch.manual_seed(opt.manual_seed)

    vis = Visualizer(opt)
    model = get_model(opt)
    if opt.optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    elif opt.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr)
    else:
        NotImplementedError("Only Adam and SGD are supported")
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'min',
                                               patience=opt.lr_patience)

    ########################################
    #              Transforms              #
    ########################################
    if not opt.no_train:
예제 #2
0
def main():
    """Train (or evaluate) a DeepLab-family semantic segmentation model.

    Parses CLI options, builds the data loaders, model, optimizer and LR
    scheduler, optionally restores a checkpoint, then runs an
    iteration-based training loop with periodic validation,
    checkpointing and optional visdom visualization.
    """
    opts = get_argparser().parse_args()

    # Number of classes is determined by the dataset.
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19
    elif opts.dataset.lower() == 'coca':
        # NOTE(review): 'coca' looks like a typo for 'coco' — confirm
        # against get_dataset(); left unchanged to preserve behavior.
        opts.num_classes = 3

    # Setup visualization (visdom), if enabled.
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display the parsed options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Seed all RNGs for reproducibility.
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Un-cropped VOC validation images vary in size, so they can only be
    # batched one at a time.
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    # Get the dataset and build the loaders.
    train_dst, val_dst = get_dataset(opts)

    train_loader = data.DataLoader(train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=2)
    val_loader = data.DataLoader(val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" % (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model: architecture name -> constructor.
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3plus_resnet50_multi_input': network.deeplabv3plus_resnet50_multi_input,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }

    model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics.
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer: the (pretrained) backbone trains with a 10x
    # smaller learning rate than the classifier head.
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)

    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)
    else:
        # Fixed: the original fell through silently and crashed later
        # with a NameError on 'scheduler'.
        raise ValueError("Unsupported lr_policy: %s" % opts.lr_policy)

    # Set up criterion.
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
    else:
        # Fixed: same silent fall-through as lr_policy above.
        raise ValueError("Unsupported loss_type: %s" % opts.loss_type)

    def save_ckpt(path):
        """Save the full training state (model/optimizer/scheduler)."""
        torch.save({
            "cur_itrs": cur_itrs,
            # model is wrapped in nn.DataParallel below, hence .module.
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore from a checkpoint, if given.
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    # ==========   Train Loop   ========== #
    # Fixed validation-sample indices used for visualization.
    vis_sample_id = np.random.randint(0, len(val_loader), opts.vis_num_samples,
                                      np.int32) if opts.enable_vis else None
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # undo input normalization for display

    if opts.test_only:
        model.eval()
        val_score, ret_samples = validate(
            opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)
        print(metrics.to_str(val_score))
        return

    interval_loss = 0
    while True:  # runs until cur_itrs reaches opts.total_itrs (checked below)
        # =====  Train  =====
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss
            if vis is not None:
                vis.vis_scalar('Loss', cur_itrs, np_loss)

            if (cur_itrs) % 10 == 0:
                # Report the mean loss over the last 10 iterations.
                interval_loss = interval_loss / 10
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

            if (cur_itrs) % opts.val_interval == 0:
                save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
                          (opts.model, opts.dataset, opts.output_stride))
                print("validation...")
                model.eval()
                val_score, ret_samples = validate(
                    opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt('checkpoints/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))

                if vis is not None:  # visualize validation score and samples
                    vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
                    vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
                    vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

                    for k, (img, target, lbl) in enumerate(ret_samples):
                        img = (denorm(img) * 255).astype(np.uint8)
                        target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
                        lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
                        concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along width
                        vis.vis_image('Sample %d' % k, concat_img)
                model.train()
            scheduler.step()

            if cur_itrs >= opts.total_itrs:
                return
예제 #3
0
def main():
    """Train a 2-stack hourglass pose network on MPII.

    Restores an optional checkpoint, builds the train/val loaders, then
    runs the epoch loop: adjust LR, train, validate, record history and
    save a checkpoint every epoch.
    """
    opt = TrainOptions().parse()
    train_history = PoseTrainHistory()
    checkpoint = Checkpoint()
    visualizer = Visualizer(opt)
    exp_dir = os.path.join(opt.exp_dir, opt.exp_id)
    log_name = opt.vis_env + 'log.txt'
    visualizer.log_path = os.path.join(exp_dir, log_name)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    # MPII annotates 16 keypoints.
    num_classes = 16
    net = create_hg(num_stacks=2,
                    num_modules=1,
                    num_classes=num_classes,
                    chan=256)
    net = torch.nn.DataParallel(net).cuda()
    # Optimizer.
    optimizer = torch.optim.RMSprop(net.parameters(),
                                    lr=opt.lr,
                                    alpha=0.99,
                                    eps=1e-8,
                                    momentum=0,
                                    weight_decay=0)
    # Optionally resume from a checkpoint.
    if opt.load_prefix_pose != '':
        checkpoint.save_prefix = os.path.join(exp_dir, opt.load_prefix_pose)
        # NOTE(review): the last character of the load prefix is dropped
        # (presumably a trailing separator) — confirm against Checkpoint.
        checkpoint.load_prefix = os.path.join(exp_dir,
                                              opt.load_prefix_pose)[0:-1]
        checkpoint.load_checkpoint(net, optimizer, train_history)
    else:
        checkpoint.save_prefix = exp_dir + '/'
    # Fixed: Python 2 print statements were a SyntaxError under Python 3;
    # converted to print() calls.
    print('save prefix: ', checkpoint.save_prefix)
    # Load data.
    train_loader = torch.utils.data.DataLoader(MPII(
        'dataset/mpii-hr-lsp-normalizer.json',
        '/bigdata1/zt53/data',
        is_train=True),
                                               batch_size=opt.bs,
                                               shuffle=True,
                                               num_workers=opt.nThreads,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(MPII(
        'dataset/mpii-hr-lsp-normalizer.json',
        '/bigdata1/zt53/data',
        is_train=False),
                                             batch_size=opt.bs,
                                             shuffle=False,
                                             num_workers=opt.nThreads,
                                             pin_memory=True)

    print(type(optimizer), optimizer.param_groups[0]['lr'])
    # Keypoint indices used for evaluation (excludes 6-9, 12 and 13).
    idx = [0, 1, 2, 3, 4, 5, 10, 11, 14, 15]
    if not opt.is_train:
        # Validation-only mode: evaluate once and dump the predictions.
        visualizer.log_path = os.path.join(opt.exp_dir, opt.exp_id,
                                           'val_log.txt')
        val_loss, val_pckh, predictions = validate(
            val_loader, net, train_history.epoch[-1]['epoch'], visualizer, idx,
            num_classes)
        checkpoint.save_preds(predictions)
        return
    # Training and validation.
    start_epoch = 0
    if opt.load_prefix_pose != '':
        # Resume counting from the epoch after the restored one.
        start_epoch = train_history.epoch[-1]['epoch'] + 1
    for epoch in range(start_epoch, opt.nEpochs):
        adjust_lr(opt, optimizer, epoch)
        # Train for one epoch.
        train_loss, train_pckh = train(train_loader, net, optimizer, epoch,
                                       visualizer, idx, opt)

        # Evaluate on the validation set.
        val_loss, val_pckh, predictions = validate(val_loader, net, epoch,
                                                   visualizer, idx,
                                                   num_classes)
        # Update training history and persist everything.
        e = OrderedDict([('epoch', epoch)])
        lr = OrderedDict([('lr', optimizer.param_groups[0]['lr'])])
        loss = OrderedDict([('train_loss', train_loss),
                            ('val_loss', val_loss)])
        pckh = OrderedDict([('train_pckh', train_pckh),
                            ('val_pckh', val_pckh)])
        train_history.update(e, lr, loss, pckh)
        checkpoint.save_checkpoint(net, optimizer, train_history, predictions)
        visualizer.plot_train_history(train_history)
예제 #4
0
def main():
    """Train blendNet on pre-extracted features, tracking the best val loss.

    Sets the module-level globals `args` and `best_loss`; relies on the
    module-level `parser`.
    """
    global args, best_loss
    best_loss = sys.float_info.max

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    model = blendNet()
    model = torch.nn.DataParallel(model).cuda()
    visualizer = Visualizer(args)

    cudnn.benchmark = True
    # Extra DataLoader workers / pinned memory only help on CUDA.
    kwargs = {
        'num_workers': 10,
        'pin_memory': True
    } if args.cuda else {}

    train_loader = torch.utils.data.DataLoader(FeatureLoader(
        '/media/dragonx/DataStorage/download/5label_mean_std/',
        True,
        transform=transforms.Compose([
            transforms.ToTensor(),
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)

    test_loader = torch.utils.data.DataLoader(FeatureLoader(
        '/media/dragonx/DataStorage/download/5label_mean_std/',
        False,
        transform=transforms.Compose([
            transforms.ToTensor(),
        ])),
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              **kwargs)
    criterion = torch.nn.CrossEntropyLoss().cuda()
    # Only optimize parameters that require grad (finetuning).
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    if args.evaluate:
        validate(test_loader, model, criterion)
        return
    errors_val_set = []
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # Train for one epoch.
        train(train_loader, model, criterion, optimizer, epoch, args,
              visualizer)
        # Evaluate on the validation set.
        visualizer.reset()
        loss, score = validate(test_loader, model, criterion)
        # Fixed: a single dict was created once and mutated every epoch,
        # so every element of errors_val_set aliased the same (final)
        # values. Append a fresh dict per epoch instead.
        errors_val_set.append({'val_loss': loss, 'val_accu': score})
        print('evaluation loss is %f at epoch %d' % (loss, epoch))
        # Remember the best (lowest) validation loss and save a checkpoint.
        is_best = loss < best_loss
        best_loss = min(loss, best_loss)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
            }, is_best)
예제 #5
0
def main():
    """Train DeepLabV2-ResNet101-MSC on COCO-Stuff-style data.

    Loads the YAML config, builds the dataset/model/optimizer, then runs
    an iteration-based loop with gradient accumulation (ITER_SIZE), poly
    LR decay, periodic logging and model snapshotting.
    """
    config = "config/cocostuff.yaml"
    cuda = True
    device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")

    if cuda:
        current_device = torch.cuda.current_device()
        print("Running on", torch.cuda.get_device_name(current_device))
    else:
        print("Running on CPU")

    # Configuration. Fixed: the original leaked the open file handle and
    # called yaml.load() without a Loader (rejected by PyYAML >= 6 and
    # unsafe on untrusted input); safe_load closes both gaps.
    with open(config) as f:
        CONFIG = Dict(yaml.safe_load(f))
    # NOTE(review): 'EXPERIENT' looks like a typo for 'EXPERIMENT', but it
    # must match the key in the YAML file — confirm before renaming.
    CONFIG.SAVE_DIR = osp.join(CONFIG.SAVE_DIR, CONFIG.EXPERIENT)
    CONFIG.LOGNAME = osp.join(CONFIG.SAVE_DIR, "log.txt")

    # Dataset
    dataset = MultiDataSet(
        CONFIG.ROOT,
        CONFIG.CROPSIZE,
        preload=False
    )

    # DataLoader
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.BATCH_SIZE,
        num_workers=CONFIG.NUM_WORKERS,
        shuffle=True,
    )
    loader_iter = iter(loader)

    # Model
    model = DeepLabV2_ResNet101_MSC(n_classes=CONFIG.N_CLASSES)
    state_dict = torch.load(CONFIG.INIT_MODEL)
    model.load_state_dict(state_dict, strict=False)  # Skip "aspp" layer
    model = nn.DataParallel(model)
    # Resume from the latest snapshot when not starting at iteration 1.
    if CONFIG.ITER_START != 1:
        load_network(CONFIG.SAVE_DIR, model, "SateDeepLab", "latest")
        print("load previous model succeed, training start from iteration {}".format(CONFIG.ITER_START))
    model.to(device)

    # Optimizer: per-group LRs mirror lr_mult/decay_mult in train.prototxt.
    optimizer = {
        "sgd": torch.optim.SGD(
            params=[
                {
                    "params": get_lr_params(model.module, key="1x"),
                    "lr": CONFIG.LR,
                    "weight_decay": CONFIG.WEIGHT_DECAY,
                },
                {
                    "params": get_lr_params(model.module, key="10x"),
                    "lr": 10 * CONFIG.LR,
                    "weight_decay": CONFIG.WEIGHT_DECAY,
                },
                {
                    "params": get_lr_params(model.module, key="20x"),
                    "lr": 20 * CONFIG.LR,
                    "weight_decay": 0.0,
                },
            ],
            momentum=CONFIG.MOMENTUM,
        )
    }.get(CONFIG.OPTIMIZER)
    if optimizer is None:
        # Fixed: an unknown OPTIMIZER previously surfaced much later as
        # an AttributeError on None.
        raise ValueError("Unsupported optimizer: {}".format(CONFIG.OPTIMIZER))

    # Loss definition
    criterion = SoftCrossEntropyLoss2d()
    criterion.to(device)

    # Visualizer
    vis = Visualizer(CONFIG.DISPLAYPORT)

    model.train()
    # Keep BatchNorm statistics frozen during finetuning.
    model.module.scale.freeze_bn()
    iter_start_time = time.time()
    for iteration in range(CONFIG.ITER_START, CONFIG.ITER_MAX + 1):
        # Set the polynomially decayed learning rate for this iteration.
        poly_lr_scheduler(
            optimizer=optimizer,
            init_lr=CONFIG.LR,
            iter=iteration - 1,
            lr_decay_iter=CONFIG.LR_DECAY,
            max_iter=CONFIG.ITER_MAX,
            power=CONFIG.POLY_POWER,
        )

        # Clear gradients (ready to accumulate over ITER_SIZE batches).
        optimizer.zero_grad()

        iter_loss = 0
        for i in range(1, CONFIG.ITER_SIZE + 1):
            try:
                data, target = next(loader_iter)
            except StopIteration:
                # Fixed: was a bare `except:` that would also swallow
                # KeyboardInterrupt and real data-loading errors.
                # Restart the (shuffled) loader for a new pass.
                loader_iter = iter(loader)
                data, target = next(loader_iter)

            # Image
            data = data.to(device)

            # Propagate forward
            outputs = model(data)

            # Loss summed over the multi-scale outputs.
            loss = 0
            for output in outputs:
                # Resize target for {100%, 75%, 50%, Max} outputs
                target_ = resize_target(target, output.size(2))
                classmap = class_to_target(target_, CONFIG.N_CLASSES)
                target_ = label_bluring(classmap)  # soft cross-entropy target
                target_ = torch.from_numpy(target_).float()
                target_ = target_.to(device)
                loss += criterion(output, target_)

            # Average over the accumulation steps, then backpropagate.
            loss /= float(CONFIG.ITER_SIZE)
            loss.backward()

            iter_loss += float(loss)

        # Update weights with the accumulated gradients.
        optimizer.step()
        # Periodic loss logging (appended to the log file).
        if iteration % CONFIG.ITER_TF == 0:
            # Fixed: log via a context manager instead of leaking an open
            # file handle on every log line.
            with open(CONFIG.LOGNAME, "a") as log_f:
                print("itr {}, loss is {}".format(iteration, iter_loss), file=log_f)
        # Periodic numbered snapshot.
        if iteration % CONFIG.ITER_SNAP == 0:
            save_network(CONFIG.SAVE_DIR, model, "SateDeepLab", iteration)

        # Frequently refreshed "latest" snapshot for resuming.
        if iteration % 100 == 0:
            save_network(CONFIG.SAVE_DIR, model, "SateDeepLab", "latest")

    save_network(CONFIG.SAVE_DIR, model, "SateDeepLab", "final")
예제 #6
0
파일: train.py 프로젝트: BaiiYuan/ADL-Final
from options.train_options import TrainOptions
from dataloader.data_loader import dataloader
from model import create_model
from utils.visualizer import Visualizer

if __name__ == '__main__':
    # get training options
    opt = TrainOptions().parse()
    # create a dataset
    dataset = dataloader(opt)
    dataset_size = len(dataset) * opt.batchSize
    print('training images = %d' % dataset_size)
    # create a model
    model = create_model(opt)
    # create a visualizer
    visualizer = Visualizer(opt)
    # training flag
    keep_training = True
    max_iteration = opt.niter + opt.niter_decay
    epoch = 0
    total_iteration = opt.iter_count

    # training process
    while (keep_training):
        epoch_start_time = time.time()
        epoch += 1
        print('\n Training epoch: %d' % epoch)

        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            total_iteration += 1
예제 #7
0
class TrainEngine:
    """Drives training and evaluation of a SimpleModel over the attached
    train/val dataloaders, logging and plotting via a Visualizer."""

    def __init__(self, opt):
        self.opt = opt
        self.cudnn_init()
        self.model = simple_model.SimpleModel(opt)
        self.train_dataloader = None
        self.val_dataloader = None
        self.visualizer = Visualizer(opt, False)

    def set_data(self, train_dataloader, val_dataloader):
        """Attach dataloaders; must be called before train_model()."""
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader

    def cudnn_init(self):
        """Enable cudnn autotuning and seed all RNGs from opt.seed."""
        torch.backends.cudnn.benchmark = True
        random.seed(self.opt.seed)
        torch.manual_seed(self.opt.seed)
        torch.cuda.manual_seed(self.opt.seed)
        np.random.seed(self.opt.seed)

    def evaluate(self, dataloader, epoch=0):
        """Evaluate on `dataloader`; return {'DICE','IoU'} AverageValueMeters."""
        epoch_acc = {
            'DICE': tnt.meter.AverageValueMeter(),
            'IoU': tnt.meter.AverageValueMeter()
        }
        # Fixed: pre-bind the running averages so an empty dataloader no
        # longer raises NameError in the log call below.
        average_dices = float('nan')
        average_iou = float('nan')

        self.model.set_eval()
        with torch.no_grad():
            progress = tqdm(dataloader)
            dice_scores = []
            iou_scores = []
            for batch_itr, data in enumerate(progress):
                self.model.set_input(data)
                self.model.forward()
                # ICNet produces a list of multi-scale masks; score the first.
                if self.opt.model == 'ICNet':
                    est_mask = self.model.est_mask[0]
                else:
                    est_mask = self.model.est_mask
                output_classes = est_mask.data.cpu().numpy().argmax(axis=1)
                target_classes = self.model.real_mask.data.cpu().numpy()
                batch_dice = general_dice(target_classes, output_classes)
                batch_iou = general_jaccard(target_classes, output_classes)
                dice_scores.append(batch_dice)
                iou_scores.append(batch_iou)
                # Fixed: the meters were created but never fed, so the
                # caller's acc_metric[key].mean was always empty/NaN.
                epoch_acc['DICE'].add(batch_dice)
                epoch_acc['IoU'].add(batch_iou)
                progress.set_description('[Testing]')
                average_dices = np.mean(dice_scores)
                average_iou = np.mean(iou_scores)
                progress.set_postfix(DICE=average_dices, IoU=average_iou)
        self.visualizer.add_log('[Testing]:DICE:%f, IoU:%f' %
                                (average_dices, average_iou))
        self.visualizer.save_images(self.model.get_current_visuals(), epoch)
        self.model.set_train()
        return epoch_acc

    def train_model(self):
        """Run opt.niter training epochs with periodic evaluation/saving."""
        training_time = 0.0
        for cur_iter in range(0, self.opt.niter):
            running_loss = 0.0
            tic = time.time()

            progress = tqdm(self.train_dataloader)
            batch_accum = 0
            for batch_itr, data in enumerate(progress):
                self.model.set_input(data)
                self.model.optimize_parameters()
                running_loss += self.model.get_loss().item()
                batch_accum += data[0].size(0)
                progress.set_description('[Training Epoch %d/%d]' %
                                         (cur_iter, self.opt.niter))
                progress.set_postfix(loss=running_loss / batch_accum)
            self.model.scheduler.step()
            self.visualizer.plot_errors(
                {'train': running_loss / len(self.train_dataloader.dataset)},
                main_fig_title='err')
            self.visualizer.add_log('[Training Epoch %d/%d]:%f' %
                                    (cur_iter, self.opt.niter, running_loss /
                                     len(self.train_dataloader.dataset)))
            training_time += time.time() - tic

            if cur_iter % self.opt.evaluate_freq == 0:
                acc_metric = self.evaluate(self.val_dataloader, epoch=cur_iter)
                for key in acc_metric.keys():
                    self.visualizer.plot_errors({'test': acc_metric[key].mean},
                                                main_fig_title=key)

            if cur_iter % self.opt.save_epoch_freq == 0:
                print('saving the model at the end of epoch %d' % cur_iter)
                self.model.save('latest')
                self.model.save(cur_iter)
예제 #8
0
def main():
    """Train ResDUCNet monocular-depth estimation on NYU Depth.

    Builds the loaders (the official test split doubles as validation),
    optionally resumes weights, then runs the epoch loop with visdom
    logging; the best-by-validation-error and final models are saved.
    """
    # Load training data.
    train_loader = torch.utils.data.DataLoader(NYUDepthDataset(
        cfg.trainval_data_root,
        'train',
        sample_num=cfg.sample_num,
        superpixel=False,
        relative=False,
        transform=True),
                                               batch_size=cfg.batch_size,
                                               shuffle=True,
                                               num_workers=cfg.num_workers,
                                               drop_last=True)
    print('Train Batches:', len(train_loader))

    test_set = NyuDepthMat(
        cfg.test_data_root,
        '/home/ans/PycharmProjects/SDFCN/data/testIdxs.txt')
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=cfg.batch_size,
                                              shuffle=True,
                                              drop_last=True)

    # The held-out test split doubles as the validation set here.
    val_loader = test_loader

    # Load model (ResNet-50 backbone, randomly initialized).
    model = ResDUCNet(model=torchvision.models.resnet50(pretrained=False))

    loss_fn = berHu()

    if cfg.use_gpu:
        print('Use CUDA')
        model = model.cuda()
        # NOTE(review): on GPU the berHu loss above is replaced by L1 —
        # looks like a leftover experiment; confirm which loss is intended.
        loss_fn = torch.nn.L1Loss().cuda()

    start_epoch = 0
    # Fixed: use +inf so the first validation error always becomes the
    # initial best (the old 10e3 silently rejected very poor first epochs).
    best_val_err = float('inf')

    if cfg.resume_from_file:
        if os.path.isfile(cfg.resume_file):
            print("=> loading checkpoint '{}'".format(cfg.resume_file))
            checkpoint = torch.load(cfg.resume_file)
            start_epoch = 0
            # Fixed: this script saves weights under 'state_dict' but only
            # loaded 'model_state', so it could not resume its own
            # checkpoints. Accept either key.
            state = checkpoint.get('model_state',
                                   checkpoint.get('state_dict'))
            model.load_state_dict(state)
        else:
            print("=> no checkpoint found at '{}'".format(cfg.resume_file))

    vis = Visualizer(cfg.env)
    # Optimizer and step-decay LR schedule.
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    print("optimizer set.")
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=cfg.step,
                                    gamma=cfg.lr_decay)

    def save_ckpt(epoch_num):
        """Persist the model weights for the given (1-based) epoch."""
        if not os.path.exists(cfg.checkpoint_dir):
            os.mkdir(cfg.checkpoint_dir)
        torch.save(
            {
                'epoch': epoch_num,
                'state_dict': model.state_dict(),
            },
            os.path.join(
                cfg.checkpoint_dir,
                '{}_{}_epoch_{}_{}'.format(cfg.checkpoint, cfg.env,
                                           epoch_num,
                                           cfg.checkpoint_postfix)))

    for epoch in range(cfg.num_epochs):

        print('Starting train epoch %d / %d, lr=%f' %
              (start_epoch + epoch + 1, cfg.num_epochs,
               optimizer.state_dict()['param_groups'][0]['lr']))

        model.train()
        running_loss = 0
        count = 0

        for i_batch, sample_batched in enumerate(train_loader):
            # NOTE(review): 'dtype' is assumed to be a module-level tensor
            # type (e.g. torch.cuda.FloatTensor) — confirm.
            input_var = Variable(sample_batched['rgb'].type(dtype))
            depth_var = Variable(sample_batched['depth'].type(dtype))

            optimizer.zero_grad()
            output = model(input_var)
            loss = loss_fn(output, depth_var)

            if i_batch % cfg.print_freq == cfg.print_freq - 1:
                print('{0} batches, loss:{1}'.format(i_batch + 1,
                                                     loss.data.cpu().item()))
                vis.plot('loss', loss.data.cpu().item())

            if i_batch % (cfg.print_freq * 10) == (cfg.print_freq * 10) - 1:
                vis.depth('pred', output)
                vis.depth('depth', sample_batched['depth'].type(dtype))

            count += 1
            running_loss += loss.data.cpu().numpy()

            loss.backward()
            optimizer.step()

        # Fixed: step the scheduler AFTER the epoch's optimizer updates.
        # Stepping at the top of the loop decayed the LR one epoch early
        # and is the deprecated ordering in current PyTorch.
        scheduler.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        val_error, val_rmse = validate(val_loader, model, loss_fn, vis=vis)
        vis.plot('val_error', val_error)
        vis.plot('val_rmse', val_rmse)
        vis.log('epoch:{epoch},lr={lr},epoch_loss:{loss},val_error:{val_cm}'.
                format(epoch=start_epoch + epoch + 1,
                       loss=epoch_loss,
                       val_cm=val_error,
                       lr=optimizer.state_dict()['param_groups'][0]['lr']))

        # Save whenever the validation error improves.
        if val_error < best_val_err:
            best_val_err = val_error
            save_ckpt(start_epoch + epoch + 1)

    # Always save the final model.
    save_ckpt(start_epoch + epoch + 1)
예제 #9
0
def main():
    """Train DUCNet (ResNet-50 backbone) on NYU Depth.

    The loss is pixel-wise MSE plus a relative-ordering penalty on
    sampled point pairs.  After every epoch the model is validated and
    the best checkpoint (by validation error) is saved; a final
    snapshot is written when training ends.

    Relies on module-level names: ``cfg``, ``dtype``, the dataset and
    model classes, and the ``validate``/``generate_relative_pos``
    helpers.  GPU is required (the script exits otherwise).
    """
    # ---------------- data ----------------
    train_loader = torch.utils.data.DataLoader(NYUDepthDataset(
        cfg.trainval_data_root,
        'train',
        sample_num=cfg.sample_num,
        superpixel=False,
        relative=True,
        transform=True),
                                               batch_size=cfg.batch_size,
                                               shuffle=True,
                                               num_workers=cfg.num_workers,
                                               drop_last=True)
    print('Train Batches:', len(train_loader))

    test_set = NyuDepthMat(
        cfg.test_data_root,
        '/home/ans/PycharmProjects/SDFCN/data/testIdxs.txt')
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=cfg.batch_size,
                                              shuffle=True,
                                              drop_last=True)
    # NOTE(review): validation currently runs on the *test* split.
    val_loader = test_loader

    # ---------------- model / losses ----------------
    model = DUCNet(model=torchvision.models.resnet50(pretrained=True))

    if cfg.use_gpu:
        print('Use CUDA')
        model = model.cuda()
        rela_loss = relativeloss().cuda()
        loss_fn = torch.nn.MSELoss().cuda()
    else:
        # CPU training is not supported by this script.
        exit(0)

    start_epoch = 0
    best_val_err = float('inf')

    vis = Visualizer(cfg.env)
    print('Created visdom environment:', cfg.env)

    # ---------------- optimizer ----------------
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    print("optimizer set.")
    scheduler = lr_scheduler.StepLR(optimizer, step_size=cfg.step, gamma=0.1)

    def save_checkpoint(epoch_num):
        # Persist model weights only (optimizer state intentionally
        # omitted, as in the original checkpoint format).
        os.makedirs(cfg.checkpoint_dir, exist_ok=True)
        torch.save(
            {
                'epoch': epoch_num,
                'state_dict': model.state_dict(),
            },
            os.path.join(
                cfg.checkpoint_dir,
                '{}_{}_epoch_{}_{}'.format(cfg.checkpoint, cfg.env,
                                           epoch_num,
                                           cfg.checkpoint_postfix)))

    for epoch in range(cfg.num_epochs):
        print('Starting train epoch %d / %d, lr=%f' %
              (start_epoch + epoch + 1, cfg.num_epochs,
               optimizer.state_dict()['param_groups'][0]['lr']))

        model.train()
        running_loss = 0
        count = 0

        for i_batch, sample_batched in enumerate(train_loader):
            input_var = Variable(sample_batched['rgb'].type(dtype))
            depth_var = Variable(sample_batched['depth'].type(dtype))

            optimizer.zero_grad()
            output = model(input_var)
            # MSE term + relative-depth ordering term on sampled pairs.
            loss1 = loss_fn(output, depth_var)
            Ah, Aw, Bh, Bw = generate_relative_pos(sample_batched['center'])
            loss2 = rela_loss(output[..., 0, Ah, Aw], output[..., 0, Bh, Bw],
                              sample_batched['ord'])
            loss = loss1 + loss2

            if i_batch % cfg.print_freq == cfg.print_freq - 1:
                print('{0} batches, loss:{1}, berhu:{2}, relative:{3}'.format(
                    i_batch + 1,
                    loss.data.cpu().item(),
                    loss1.data.cpu().item(),
                    loss2.data.cpu().item()))
                vis.plot('loss', loss.data.cpu().item())

            if i_batch % (cfg.print_freq * 10) == (cfg.print_freq * 10) - 1:
                vis.depth('pred', output)
                vis.depth('depth', sample_batched['depth'].type(dtype))

            count += 1
            running_loss += loss.data.cpu().numpy()

            loss.backward()
            optimizer.step()

        # BUG FIX: step the LR scheduler *after* the epoch's optimizer
        # steps (required since PyTorch 1.1).  Stepping before the loop
        # decayed the learning rate one epoch early.
        scheduler.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        val_error, val_rmse = validate(val_loader, model, loss_fn, vis=vis)
        vis.plot('val_error', val_error)
        vis.plot('val_rmse', val_rmse)
        vis.log('epoch:{epoch},lr={lr},epoch_loss:{loss},val_error:{val_cm}'.
                format(epoch=start_epoch + epoch + 1,
                       loss=epoch_loss,
                       val_cm=val_error,
                       lr=optimizer.state_dict()['param_groups'][0]['lr']))

        # Keep the best checkpoint by validation error.
        if val_error < best_val_err:
            best_val_err = val_error
            save_checkpoint(start_epoch + epoch + 1)

    # Final snapshot.  Computed from cfg.num_epochs instead of reusing
    # the loop variable, which raised NameError when num_epochs == 0.
    if cfg.num_epochs > 0:
        save_checkpoint(start_epoch + cfg.num_epochs)
예제 #10
0
def visualize_from_pred_dict(pred_eval_dict=None,
                             pred_eval_dict_path=None,
                             refvg_split=None,
                             out_path=None,
                             pred_bin_tags=None,
                             pred_score_tags=None,
                             pred_box_tags=None,
                             all_task_num=40,
                             subset_task_num=20,
                             gt_skip_exist=True,
                             pred_skip_exist=True,
                             verbose=True):
    """Plot a sampled set of predictions and build the HTML report.

    Either ``pred_eval_dict`` (a nested ``{img_id: {task_id: pred}}``
    dict) or ``pred_eval_dict_path`` (an ``.npy`` pickle of the same)
    must be given.  Tasks are bucketed by their ``'subsets'`` entry
    (written in place as ``['all']`` when missing), down-sampled per
    bucket, plotted, and summarized as HTML under ``out_path``.
    """
    if pred_eval_dict is not None:
        predictions = pred_eval_dict
    else:
        predictions = np.load(pred_eval_dict_path, allow_pickle=True).item()
        assert isinstance(predictions, dict)

    # One (initially empty) bucket per subset of interest.
    wanted = subset_utils.subsets if subset_task_num > 0 else ['all']
    buckets = {name: [] for name in wanted}

    for img_id, tasks in predictions.items():
        for task_id, pred in tasks.items():
            if 'subsets' not in pred:
                # Tag untagged tasks so later passes see a subset list.
                pred['subsets'] = ['all']
                buckets['all'].append((img_id, task_id))
                continue
            for name in pred['subsets']:
                if name in buckets:
                    buckets[name].append((img_id, task_id))

    # Down-sample every bucket to its quota, in canonical subset order.
    to_plot = []
    for name in subset_utils.subsets:
        if name not in buckets:
            continue
        ids = buckets[name]
        quota = all_task_num if name == 'all' else subset_task_num
        if len(ids) > quota:
            ids = random.sample(ids, quota)
        to_plot.extend(ids)

    # Plot each sampled (image, task) pair.
    visualizer = Visualizer(refvg_split=refvg_split,
                            pred_plot_path=os.path.join(
                                out_path, 'pred_plots'),
                            pred_skip_exist=pred_skip_exist,
                            gt_skip_exist=gt_skip_exist)
    for img_id, task_id in to_plot:
        visualizer.plot_single_task(img_id,
                                    task_id,
                                    predictions[img_id][task_id],
                                    pred_bin_tags=pred_bin_tags,
                                    pred_score_tags=pred_score_tags,
                                    pred_box_tags=pred_box_tags,
                                    verbose=verbose)

    # Emit the HTML report, attaching results.txt only if it exists.
    html_path = os.path.join(out_path, 'htmls')
    result_path = os.path.join(out_path, 'results.txt')
    if not os.path.exists(result_path):
        result_path = None

    visualizer.generate_html(html_path,
                             enable_subsets=subset_task_num > 0,
                             result_txt_path=result_path)
    return
예제 #11
0
from data import create_dataset
from models import create_model
from utils.visualizer import Visualizer

if __name__ == '__main__':
    opt = TrainOptions().parse()  # get training options
    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.
    print('The number of training images = %d' % dataset_size)

    model = create_model(
        opt)  # create a model given opt.model and other options
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    visualizer = Visualizer(
        opt)  # create a visualizer that display/save images and plots
    total_iters = 0  # the total number of training iterations

    for epoch in range(
            model.epoch_count, model.n_epochs + model.n_epochs_decay + 1
    ):  # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()  # timer for data loading per iteration
        epoch_iter = 0  # the number of training iterations in current epoch, reset to 0 every epoch
        visualizer.reset(
        )  # reset the visualizer: make sure it saves the results to HTML at least once every epoch

        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time(
            )  # timer for computation per iteration
            if total_iters % opt.print_freq == 0:
예제 #12
0
import cfgs.config as cfg
from random import randint
from utils.visualizer import Visualizer



'''
默认的初始化过程
('voc_2007_trainval', '/home/qlt/qiulingteng/detection/yolo2-pytorch-master/data', 16, <function preprocess_train at 0x7f61c510de60>)
[array([320, 320]), array([352, 352]), array([384, 384]), array([416, 416]), array([448, 448]), array([480, 480]), array([512, 512]), array([544, 544]), array([576, 576])]
'''
# data loader
# VOC training set with multi-scale target sizes (cfg.multi_scale_inp_size);
# batches are preprocessed by yolo_utils.preprocess_train in one worker
# process and shuffled.
imdb = VOCDataset(cfg.imdb_train, cfg.DATA_DIR, cfg.train_batch_size,
                  yolo_utils.preprocess_train, processes=1, shuffle=True,
                  dst_size=cfg.multi_scale_inp_size)
viz = Visualizer()  # training visualizer (no options passed)
# dst_size=cfg.inp_size)
print('load data succ...')

# Darknet-19 network: load pretrained weights for its first 18 conv
# layers from an .npz file (presumably a classification-pretrained
# backbone — TODO confirm), then move to GPU and switch to train mode.
net = Darknet19()
# net_utils.load_net(cfg.trained_model, net)
# pretrained_model = os.path.join(cfg.train_output_dir,
#     'darknet19_voc07trainval_exp1_63.h5')
# pretrained_model = cfg.trained_model
# net_utils.load_net(pretrained_model, net)
net.load_from_npz(cfg.pretrained_model, num_conv=18)
net.cuda()
net.train()
print('load net succ...')

print("For this training we have follow para:\n"
예제 #13
0
    def train(self,
              edgenetpath=None,
              sr2x1_path=None,
              sr2x2_path=None,
              srcnn_path=None,
              srresnet_path=None,
              is_fine_tune=False,
              random_scale=True,
              rotate=True,
              fliplr=True,
              fliptb=True):
        vis = Visualizer(self.env)

        print('================ Loading datasets =================')
        # load training dataset
        print('## Current Mode: Train')
        # train_data_loader = self.load_dataset(mode='valid')
        train_data_loader = self.load_dataset(mode='train',
                                              random_scale=random_scale,
                                              rotate=rotate,
                                              fliplr=fliplr,
                                              fliptb=fliptb)

        t_save_dir = 'results/train_result/' + self.train_dataset + "_{}"
        if not os.path.exists(t_save_dir.format("origin")):
            os.makedirs(t_save_dir.format("origin"))
        if not os.path.exists(t_save_dir.format("lr4x")):
            os.makedirs(t_save_dir.format("lr4x"))
        if not os.path.exists(t_save_dir.format("srunit_2x")):
            os.makedirs(t_save_dir.format("srunit_2x"))
        if not os.path.exists(t_save_dir.format("bicubic")):
            os.makedirs(t_save_dir.format("bicubic"))
        if not os.path.exists(t_save_dir.format("bicubic2x")):
            os.makedirs(t_save_dir.format("bicubic2x"))
        if not os.path.exists(t_save_dir.format("srunit_common")):
            os.makedirs(t_save_dir.format("srunit_common"))
        if not os.path.exists(t_save_dir.format("srunit_2xbicubic")):
            os.makedirs(t_save_dir.format("srunit_2xbicubic"))
        if not os.path.exists(t_save_dir.format("srunit_4xbicubic")):
            os.makedirs(t_save_dir.format("srunit_4xbicubic"))
        if not os.path.exists(t_save_dir.format("srresnet")):
            os.makedirs(t_save_dir.format("srresnet"))
        if not os.path.exists(t_save_dir.format("srcnn")):
            os.makedirs(t_save_dir.format("srcnn"))

        ##########################################################
        ##################### build network ######################
        ##########################################################
        print('Building Networks and initialize parameters\' weights....')
        # init sr resnet
        srresnet2x1 = Upscale2xResnetGenerator(input_nc=3,
                                               output_nc=3,
                                               n_blocks=5,
                                               norm=NORM,
                                               activation='prelu',
                                               learn_residual=True)
        srresnet2x2 = Upscale2xResnetGenerator(input_nc=3,
                                               output_nc=3,
                                               n_blocks=5,
                                               norm=NORM,
                                               activation='prelu',
                                               learn_residual=True)
        srresnet2x1.apply(weights_init_normal)
        srresnet2x2.apply(weights_init_normal)

        # init srresnet
        srresnet = SRResnet()
        srresnet.apply(weights_init_normal)

        # init srcnn
        srcnn = SRCNN()
        srcnn.apply(weights_init_normal)

        # init discriminator
        discnet = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=5)

        # init edgenet
        edgenet = HED_1L()
        if edgenetpath is None or not os.path.exists(edgenetpath):
            raise Exception('Invalid edgenet model')
        else:
            pretrained_dict = torch.load(edgenetpath)
            model_dict = edgenet.state_dict()
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            edgenet.load_state_dict(model_dict)

        # init vgg feature
        featuremapping = VGGFeatureMap(models.vgg19(pretrained=True))

        # load pretrained srresnet or just initialize
        if sr2x1_path is None or not os.path.exists(sr2x1_path):
            print('===> initialize the srresnet2x1')
            print('======> No pretrained model')
        else:
            print('======> loading the weight from pretrained model')
            # deblurnet.load_state_dict(torch.load(sr2x1_path))
            pretrained_dict = torch.load(sr2x1_path)
            model_dict = srresnet2x1.state_dict()

            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            srresnet2x1.load_state_dict(model_dict)

        if sr2x2_path is None or not os.path.exists(sr2x2_path):
            print('===> initialize the srresnet2x2')
            print('======> No pretrained model')
        else:
            print('======> loading the weight from pretrained model')
            # deblurnet.load_state_dict(torch.load(sr2x2_path))
            pretrained_dict = torch.load(sr2x2_path)
            model_dict = srresnet2x2.state_dict()

            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            srresnet2x2.load_state_dict(model_dict)

        if srresnet_path is None or not os.path.exists(srresnet_path):
            print('===> initialize the srcnn')
            print('======> No pretrained model')
        else:
            print('======> loading the weight from pretrained model')
            pretrained_dict = torch.load(srresnet_path)
            model_dict = srresnet.state_dict()

            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            srresnet.load_state_dict(model_dict)

        if srcnn_path is None or not os.path.exists(srcnn_path):
            print('===> initialize the srcnn')
            print('======> No pretrained model')
        else:
            print('======> loading the weight from pretrained model')
            pretrained_dict = torch.load(srcnn_path)
            model_dict = srcnn.state_dict()

            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            srcnn.load_state_dict(model_dict)

        # optimizer init
        # different learning rate
        lr = self.lr

        srresnet2x1_optimizer = optim.Adam(srresnet2x1.parameters(),
                                           lr=lr,
                                           betas=(0.9, 0.999))
        srresnet2x2_optimizer = optim.Adam(srresnet2x2.parameters(),
                                           lr=lr,
                                           betas=(0.9, 0.999))
        srresnet_optimizer = optim.Adam(srresnet.parameters(),
                                        lr=lr,
                                        betas=(0.9, 0.999))
        srcnn_optimizer = optim.Adam(srcnn.parameters(),
                                     lr=lr,
                                     betas=(0.9, 0.999))
        disc_optimizer = optim.Adam(discnet.parameters(),
                                    lr=lr / 10,
                                    betas=(0.9, 0.999))

        # loss function init
        MSE_loss = nn.MSELoss()
        BCE_loss = nn.BCELoss()

        # cuda accelerate
        if USE_GPU:
            edgenet.cuda()
            srresnet2x1.cuda()
            srresnet2x2.cuda()
            srresnet.cuda()
            srcnn.cuda()
            discnet.cuda()
            featuremapping.cuda()
            MSE_loss.cuda()
            BCE_loss.cuda()
            print('\tCUDA acceleration is available.')

        ##########################################################
        ##################### train network ######################
        ##########################################################
        import torchnet as tnt
        from tqdm import tqdm
        from PIL import Image

        batchnorm = nn.BatchNorm2d(1).cuda()
        upsample = nn.Upsample(scale_factor=2, mode='bilinear')

        edge_avg_loss = tnt.meter.AverageValueMeter()
        total_avg_loss = tnt.meter.AverageValueMeter()
        disc_avg_loss = tnt.meter.AverageValueMeter()
        psnr_2x_avg = tnt.meter.AverageValueMeter()
        ssim_2x_avg = tnt.meter.AverageValueMeter()
        psnr_4x_avg = tnt.meter.AverageValueMeter()
        ssim_4x_avg = tnt.meter.AverageValueMeter()

        psnr_bicubic_avg = tnt.meter.AverageValueMeter()
        ssim_bicubic_avg = tnt.meter.AverageValueMeter()
        psnr_2xcubic_avg = tnt.meter.AverageValueMeter()
        ssim_2xcubic_avg = tnt.meter.AverageValueMeter()
        psnr_4xcubic_avg = tnt.meter.AverageValueMeter()
        ssim_4xcubic_avg = tnt.meter.AverageValueMeter()

        psnr_srresnet_avg = tnt.meter.AverageValueMeter()
        ssim_srresnet_avg = tnt.meter.AverageValueMeter()

        psnr_srcnn_avg = tnt.meter.AverageValueMeter()
        ssim_srcnn_avg = tnt.meter.AverageValueMeter()

        srresnet2x1.train()
        srresnet2x2.train()
        srresnet.train()
        srcnn.train()
        discnet.train()
        itcnt = 0
        for epoch in range(self.num_epochs):
            psnr_2x_avg.reset()
            ssim_2x_avg.reset()
            psnr_4x_avg.reset()
            ssim_4x_avg.reset()
            psnr_bicubic_avg.reset()
            ssim_bicubic_avg.reset()
            psnr_2xcubic_avg.reset()
            ssim_2xcubic_avg.reset()
            psnr_4xcubic_avg.reset()
            ssim_4xcubic_avg.reset()
            psnr_srresnet_avg.reset()
            ssim_srresnet_avg.reset()
            psnr_srcnn_avg.reset()
            ssim_srcnn_avg.reset()

            # learning rate is decayed by a factor every 20 epoch
            if (epoch + 1 % 20) == 0:
                for param_group in srresnet2x1_optimizer.param_groups:
                    param_group["lr"] /= 10.0
                print("Learning rate decay for srresnet2x1: lr={}".format(
                    srresnet2x1_optimizer.param_groups[0]["lr"]))
                for param_group in srresnet2x2_optimizer.param_groups:
                    param_group["lr"] /= 10.0
                print("Learning rate decay for srresnet2x2: lr={}".format(
                    srresnet2x2_optimizer.param_groups[0]["lr"]))
                for param_group in srresnet_optimizer.param_groups:
                    param_group["lr"] /= 10.0
                print("Learning rate decay for srresnet: lr={}".format(
                    srresnet_optimizer.param_groups[0]["lr"]))
                for param_group in srcnn_optimizer.param_groups:
                    param_group["lr"] /= 10.0
                print("Learning rate decay for srcnn: lr={}".format(
                    srcnn_optimizer.param_groups[0]["lr"]))
                for param_group in disc_optimizer.param_groups:
                    param_group["lr"] /= 10.0
                print("Learning rate decay for discnet: lr={}".format(
                    disc_optimizer.param_groups[0]["lr"]))

            itbar = tqdm(enumerate(train_data_loader))
            for ii, (hr, lr2x, lr4x, bc2x, bc4x) in itbar:

                mini_batch = hr.size()[0]

                hr_ = Variable(hr)
                lr2x_ = Variable(lr2x)
                lr4x_ = Variable(lr4x)
                bc2x_ = Variable(bc2x)
                bc4x_ = Variable(bc4x)
                real_label = Variable(torch.ones(mini_batch))
                fake_label = Variable(torch.zeros(mini_batch))

                # cuda mode setting
                if USE_GPU:
                    hr_ = hr_.cuda()
                    lr2x_ = lr2x_.cuda()
                    lr4x_ = lr4x_.cuda()
                    bc2x_ = bc2x_.cuda()
                    bc4x_ = bc4x_.cuda()
                    real_label = real_label.cuda()
                    fake_label = fake_label.cuda()

                # =============================================================== #
                # ================ Edge-based srresnet training ================= #
                # =============================================================== #
                sr2x_ = srresnet2x1(lr4x_)
                sr4x_ = srresnet2x2(lr2x_)
                bc2x_sr4x_ = srresnet2x2(bc2x_)
                sr2x_bc4x_ = upsample(sr2x_)
                '''===================== Train Discriminator ====================='''
                if epoch + 1 > self.pretrain_epochs:
                    disc_optimizer.zero_grad()

                    #===== 2x disc loss =====#
                    real_decision_2x = discnet(lr2x_)
                    real_loss_2x = BCE_loss(real_decision_2x,
                                            real_label.detach())

                    fake_decision_2x = discnet(sr2x_.detach())
                    fake_loss_2x = BCE_loss(fake_decision_2x,
                                            fake_label.detach())

                    disc_loss_2x = real_loss_2x + fake_loss_2x

                    disc_loss_2x.backward()
                    disc_optimizer.step()

                    #===== 4x disc loss =====#
                    real_decision_4x = discnet(hr_)
                    real_loss_4x = BCE_loss(real_decision_4x,
                                            real_label.detach())

                    fake_decision_4x = discnet(sr4x_.detach())
                    fake_loss_4x = BCE_loss(fake_decision_4x,
                                            fake_label.detach())

                    disc_loss_4x = real_loss_4x + fake_loss_4x

                    disc_loss_4x.backward()
                    disc_optimizer.step()

                    disc_avg_loss.add(
                        (disc_loss_2x + disc_loss_4x).data.item())
                '''=================== Train srresnet Generator ==================='''
                edge_trade_off = [0.7, 0.2, 0.1, 0.05, 0.01, 0.3]
                if epoch + 1 > self.pretrain_epochs:
                    a1, a2, a3 = 0.55, 0.1, 0.75
                else:
                    a1, a2, a3 = 0.65, 0.0, 0.95

                #============ calculate 2x loss ==============#
                srresnet2x1_optimizer.zero_grad()

                #### Edgenet Loss ####
                pred = edgenet(sr2x_)
                real = edgenet(lr2x_)

                edge_loss_2x = BCE_loss(pred.detach(), real.detach())
                # for i in range(6):
                #     edge_loss_2x += edge_trade_off[i] * \
                #         BCE_loss(pred[i].detach(), real[i].detach())
                # edge_loss = 0.7 * BCE2d(pred[0], real[i]) + 0.3 * BCE2d(pred[5], real[i])

                #### Content Loss ####
                content_loss_2x = MSE_loss(
                    sr2x_, lr2x_)  #+ 0.1*BCE_loss(1-sr2x_, 1-lr2x_)

                #### Perceptual Loss ####
                real_feature = featuremapping(lr2x_)
                fake_feature = featuremapping(sr2x_)
                vgg_loss_2x = MSE_loss(fake_feature, real_feature.detach())

                #### Adversarial Loss ####
                advs_loss_2x = BCE_loss(
                    discnet(sr2x_),
                    real_label) if epoch + 1 > self.pretrain_epochs else 0

                #============ calculate scores ==============#
                psnr_2x_score_process = batch_compare_filter(
                    sr2x_.cpu().data, lr2x, PSNR)
                psnr_2x_avg.add(psnr_2x_score_process)

                ssim_2x_score_process = batch_compare_filter(
                    sr2x_.cpu().data, lr2x, SSIM)
                ssim_2x_avg.add(ssim_2x_score_process)

                #============== loss backward ===============#
                total_loss_2x = a1 * edge_loss_2x + a2 * advs_loss_2x + \
                    a3 * content_loss_2x + (1.0 - a3) * vgg_loss_2x

                total_loss_2x.backward()
                srresnet2x1_optimizer.step()

                #============ calculate 4x loss ==============#
                if is_fine_tune:
                    sr2x_ = srresnet2x1(lr4x_)
                    sr4x_ = srresnet2x2(sr2x_)

                srresnet2x2_optimizer.zero_grad()
                #### Edgenet Loss ####
                pred = edgenet(sr4x_)
                real = edgenet(hr_)

                # edge_loss_4x = 0
                edge_loss_4x = BCE_loss(pred.detach(), real.detach())
                # for i in range(6):
                #     edge_loss_4x += edge_trade_off[i] * \
                #         BCE_loss(pred[i].detach(), real[i].detach())
                # edge_loss = 0.7 * BCE2d(pred[0], real[i]) + 0.3 * BCE2d(pred[5], real[i])

                #### Content Loss ####
                content_loss_4x = MSE_loss(
                    sr4x_, hr_)  #+ 0.1*BCE_loss(1-sr4x_, 1-hr_)

                #### Perceptual Loss ####
                real_feature = featuremapping(hr_)
                fake_feature = featuremapping(sr4x_)
                vgg_loss_4x = MSE_loss(fake_feature, real_feature.detach())

                #### Adversarial Loss ####
                advs_loss_4x = BCE_loss(
                    discnet(sr4x_),
                    real_label) if epoch + 1 > self.pretrain_epochs else 0

                #============ calculate scores ==============#
                psnr_4x_score_process = batch_compare_filter(
                    sr4x_.cpu().data, hr, PSNR)
                psnr_4x_avg.add(psnr_4x_score_process)

                ssim_4x_score_process = batch_compare_filter(
                    sr4x_.cpu().data, hr, SSIM)
                ssim_4x_avg.add(ssim_4x_score_process)

                psnr_bicubic_score = batch_compare_filter(
                    bc4x_.cpu().data, hr, PSNR)
                psnr_bicubic_avg.add(psnr_bicubic_score)

                ssim_bicubic_score = batch_compare_filter(
                    bc4x_.cpu().data, hr, SSIM)
                ssim_bicubic_avg.add(ssim_bicubic_score)

                psnr_2xcubic_score = batch_compare_filter(
                    bc2x_sr4x_.cpu().data, hr, PSNR)
                psnr_2xcubic_avg.add(psnr_2xcubic_score)

                ssim_2xcubic_score = batch_compare_filter(
                    bc2x_sr4x_.cpu().data, hr, SSIM)
                ssim_2xcubic_avg.add(ssim_2xcubic_score)

                psnr_4xcubic_score = batch_compare_filter(
                    sr2x_bc4x_.cpu().data, hr, PSNR)
                psnr_4xcubic_avg.add(psnr_4xcubic_score)

                ssim_4xcubic_score = batch_compare_filter(
                    sr2x_bc4x_.cpu().data, hr, SSIM)
                ssim_4xcubic_avg.add(ssim_4xcubic_score)

                #============== loss backward ===============#
                total_loss_4x = a1 * edge_loss_4x + a2 * advs_loss_4x + \
                    a3 * content_loss_4x + (1.0 - a3) * vgg_loss_4x

                total_loss_4x.backward()
                srresnet2x2_optimizer.step()

                total_avg_loss.add((total_loss_2x + total_loss_4x).data.item())
                edge_avg_loss.add((edge_loss_2x + edge_loss_4x).data.item())
                if epoch + 1 > self.pretrain_epochs:
                    disc_avg_loss.add(
                        (advs_loss_2x + advs_loss_4x).data.item())

                if (ii + 1) % self.plot_iter == self.plot_iter - 1:
                    res = {
                        'edge loss': edge_avg_loss.value()[0],
                        'generate loss': total_avg_loss.value()[0],
                        'discriminate loss': disc_avg_loss.value()[0]
                    }
                    vis.plot_many(res, 'Deblur net Loss')

                    psnr_2x_score_origin = batch_compare_filter(
                        bc2x, lr2x, PSNR)
                    psnr_4x_score_origin = batch_compare_filter(bc4x, hr, PSNR)
                    res_psnr = {
                        '2x_origin_psnr': psnr_2x_score_origin,
                        '2x_sr_psnr': psnr_2x_score_process,
                        '4x_origin_psnr': psnr_4x_score_origin,
                        '4x_sr_psnr': psnr_4x_score_process
                    }
                    vis.plot_many(res_psnr, 'PSNR Score')

                    ssim_2x_score_origin = batch_compare_filter(
                        bc2x, lr2x, SSIM)
                    ssim_4x_score_origin = batch_compare_filter(bc4x, hr, SSIM)
                    res_ssim = {
                        '2x_origin_ssim': ssim_2x_score_origin,
                        '2x_sr_ssim': ssim_2x_score_process,
                        '4x_origin_ssim': ssim_4x_score_origin,
                        '4x_sr_ssim': ssim_4x_score_process
                    }
                    vis.plot_many(res_ssim, 'SSIM Score')

                save_img(
                    hr[0],
                    os.path.join(t_save_dir.format("origin"),
                                 "{}.jpg".format(ii)))
                save_img(
                    lr4x[0],
                    os.path.join(t_save_dir.format("lr4x"),
                                 "{}.jpg".format(ii)))
                save_img(
                    bc4x[0],
                    os.path.join(t_save_dir.format("bicubic"),
                                 "{}.jpg".format(ii)))
                save_img(
                    bc2x[0],
                    os.path.join(t_save_dir.format("bicubic2x"),
                                 "{}.jpg".format(ii)))
                save_img(
                    sr2x_.cpu().data[0],
                    os.path.join(t_save_dir.format("srunit_2x"),
                                 "{}.jpg".format(ii)))
                save_img(
                    sr4x_.cpu().data[0],
                    os.path.join(t_save_dir.format("srunit_common"),
                                 "{}.jpg".format(ii)))
                save_img(
                    bc2x_sr4x_.cpu().data[0],
                    os.path.join(t_save_dir.format("srunit_2xbicubic"),
                                 "{}.jpg".format(ii)))
                save_img(
                    sr2x_bc4x_.cpu().data[0],
                    os.path.join(t_save_dir.format("srunit_4xbicubic"),
                                 "{}.jpg".format(ii)))

                # =============================================================== #
                # ====================== srresnet training ====================== #
                # =============================================================== #
                sr4x_ = srresnet(lr4x_)

                #============ calculate 4x loss ==============#
                srresnet_optimizer.zero_grad()

                #### Content Loss ####
                content_loss_4x = MSE_loss(sr4x_, hr_)

                #### Perceptual Loss ####
                real_feature = featuremapping(hr_)
                fake_feature = featuremapping(sr4x_)
                vgg_loss_4x = MSE_loss(fake_feature, real_feature.detach())

                #============ calculate scores ==============#
                psnr_4x_score = batch_compare_filter(sr4x_.cpu().data, hr,
                                                     PSNR)
                psnr_srresnet_avg.add(psnr_4x_score)

                ssim_4x_score = batch_compare_filter(sr4x_.cpu().data, hr,
                                                     SSIM)
                ssim_srresnet_avg.add(ssim_4x_score)

                #============== loss backward ===============#
                total_loss_4x = content_loss_4x + 0.2 * vgg_loss_4x

                total_loss_4x.backward()
                srresnet_optimizer.step()

                save_img(
                    sr4x_.cpu().data[0],
                    os.path.join(t_save_dir.format("srresnet"),
                                 "{}.jpg".format(ii)))

                # =============================================================== #
                # ======================= srcnn training ======================== #
                # =============================================================== #
                sr4x_ = srcnn(bc4x_)

                #============ calculate 4x loss ==============#
                srcnn_optimizer.zero_grad()

                #### Content Loss ####
                content_loss_4x = MSE_loss(sr4x_, hr_)

                #============ calculate scores ==============#
                psnr_4x_score = batch_compare_filter(sr4x_.cpu().data, hr,
                                                     PSNR)
                psnr_srcnn_avg.add(psnr_4x_score)

                ssim_4x_score = batch_compare_filter(sr4x_.cpu().data, hr,
                                                     SSIM)
                ssim_srcnn_avg.add(ssim_4x_score)

                #============== loss backward ===============#
                total_loss_4x = content_loss_4x

                total_loss_4x.backward()
                srcnn_optimizer.step()

                save_img(
                    sr4x_.cpu().data[0],
                    os.path.join(t_save_dir.format("srcnn"),
                                 "{}.jpg".format(ii)))

                #======================= Output result of total training processing =======================#
                itcnt += 1
                itbar.set_description(
                    "Epoch: [%2d] [%d/%d] PSNR_2x_Avg: %.6f, SSIM_2x_Avg: %.6f, PSNR_4x_Avg: %.6f, SSIM_4x_Avg: %.6f"
                    % ((epoch + 1), (ii + 1), len(train_data_loader),
                       psnr_2x_avg.value()[0], ssim_2x_avg.value()[0],
                       psnr_4x_avg.value()[0], ssim_4x_avg.value()[0]))

                if (ii + 1) % self.plot_iter == self.plot_iter - 1:
                    # test_ = deblurnet(torch.cat([y_.detach(), x_edge], 1))
                    hr_edge = edgenet(hr_)
                    sr2x_edge = edgenet(sr2x_)
                    sr4x_edge = edgenet(sr4x_)

                    vis.images(hr_edge.cpu().data,
                               win='HR edge predict',
                               opts=dict(title='HR edge predict'))
                    vis.images(sr2x_edge.cpu().data,
                               win='SR2X edge predict',
                               opts=dict(title='SR2X edge predict'))
                    vis.images(sr4x_edge.cpu().data,
                               win='SR4X edge predict',
                               opts=dict(title='SR4X edge predict'))

                    sr4x_ = srresnet2x2(sr2x_)
                    vis.images(lr2x,
                               win='LR2X image',
                               opts=dict(title='LR2X image'))
                    vis.images(lr4x,
                               win='LR4X image',
                               opts=dict(title='LR4X image'))
                    vis.images(bc2x,
                               win='BC2X image',
                               opts=dict(title='BC2X image'))
                    vis.images(bc4x,
                               win='BC4X image',
                               opts=dict(title='BC4X image'))
                    vis.images(sr2x_.cpu().data,
                               win='SR2X image',
                               opts=dict(title='SR2X image'))
                    vis.images(sr4x_.cpu().data,
                               win='SR4X image',
                               opts=dict(title='SR4X image'))

                    vis.images(hr, win='HR image', opts=dict(title='HR image'))

                    res = {
                        "bicubic PSNR": psnr_bicubic_avg.value()[0],
                        "bicubic SSIM": ssim_bicubic_avg.value()[0],
                        "srunit4x PSNR": psnr_4x_avg.value()[0],
                        "srunit4x SSIM": ssim_4x_avg.value()[0],
                        "2xbicubic PSNR": psnr_2xcubic_avg.value()[0],
                        "2xbicubic SSIM": ssim_2xcubic_avg.value()[0],
                        "4xbicubic PSNR": psnr_4xcubic_avg.value()[0],
                        "4xbicubic SSIM": ssim_4xcubic_avg.value()[0],
                        "srresnet PSNR": psnr_srresnet_avg.value()[0],
                        "srresnet SSIM": ssim_srresnet_avg.value()[0],
                        "srcnn PSNR": psnr_srcnn_avg.value()[0],
                        "srcnn SSIM": ssim_srcnn_avg.value()[0]
                    }

                    vis.metrics(res, "metrics")

            if (epoch + 1) % self.save_epochs == 0:
                self.save_model(
                    srresnet2x1,
                    os.path.join(self.save_dir, 'checkpoints', 'srunitnet'),
                    'srnet2x1_param_batch{}_lr{}_epoch{}'.format(
                        self.batch_size, self.lr, epoch + 1))
                self.save_model(
                    srresnet2x2,
                    os.path.join(self.save_dir, 'checkpoints', 'srunitnet'),
                    'srnet2x2_param_batch{}_lr{}_epoch{}'.format(
                        self.batch_size, self.lr, epoch + 1))
                self.save_model(
                    srresnet,
                    os.path.join(self.save_dir, 'checkpoints', 'srresnet'),
                    'srresnet_param_batch{}_lr{}_epoch{}'.format(
                        self.batch_size, self.lr, epoch + 1))
                self.save_model(
                    srcnn, os.path.join(self.save_dir, 'checkpoints', 'srcnn'),
                    'srcnn_param_batch{}_lr{}_epoch{}'.format(
                        self.batch_size, self.lr, epoch + 1))

        # Save final trained model and results
        vis.save([self.env])
        self.save_model(
            srresnet2x1, os.path.join(self.save_dir, 'checkpoints',
                                      'srunitnet'),
            'srnet2x1_param_batch{}_lr{}_epoch{}'.format(
                self.batch_size, self.lr, self.num_epochs))
        self.save_model(
            srresnet2x2, os.path.join(self.save_dir, 'checkpoints',
                                      'srunitnet'),
            'srnet2x2_param_batch{}_lr{}_epoch{}'.format(
                self.batch_size, self.lr, self.num_epochs))
        self.save_model(
            srcnn, os.path.join(self.save_dir, 'checkpoints', 'srresnet'),
            'srresnet_param_batch{}_lr{}_epoch{}'.format(
                self.batch_size, self.lr, self.num_epochs))
        self.save_model(
            srcnn, os.path.join(self.save_dir, 'checkpoints', 'srcnn'),
            'srcnn_param_batch{}_lr{}_epoch{}'.format(self.batch_size, self.lr,
                                                      self.num_epochs))
예제 #14
0
def main():
    """Pretrain the scale/rotation selection agent (ASN) on MPII.

    Builds train/val loaders whose ground-truth scale/rotation distributions
    are read from text files produced by an earlier stage, creates the agent
    network, restores a frozen two-stack hourglass pose model, and runs the
    train/validate loop with per-epoch checkpointing and history plotting.

    All configuration comes from ``TrainOptions``; nothing is returned.
    """
    opt = TrainOptions().parse()
    if opt.sr_dir == '':
        print('sr directory is null.')
        exit()
    sr_pretrain_dir = os.path.join(
        opt.exp_dir, opt.exp_id, opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(sr_pretrain_dir):
        os.makedirs(sr_pretrain_dir)
    train_history = ASNTrainHistory()
    checkpoint_agent = Checkpoint()
    visualizer = Visualizer(opt)
    visualizer.log_path = sr_pretrain_dir + '/' + 'log.txt'
    train_scale_path = sr_pretrain_dir + '/' + 'train_scales.txt'
    train_rotation_path = sr_pretrain_dir + '/' + 'train_rotations.txt'
    val_scale_path = sr_pretrain_dir + '/' + 'val_scales.txt'
    val_rotation_path = sr_pretrain_dir + '/' + 'val_rotations.txt'

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    print('collecting training scale and rotation distributions ...\n')
    train_scale_distri = read_grnd_distri_from_txt(train_scale_path)
    train_rotation_distri = read_grnd_distri_from_txt(train_rotation_path)
    train_dataset = MPII('dataset/mpii-hr-lsp-normalizer.json',
                         '/bigdata1/zt53/data',
                         is_train=True,
                         grnd_scale_distri=train_scale_distri,
                         grnd_rotation_distri=train_rotation_distri)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.bs,
                                               shuffle=True,
                                               num_workers=opt.nThreads,
                                               pin_memory=True)

    print('collecting validation scale and rotation distributions ...\n')
    val_scale_distri = read_grnd_distri_from_txt(val_scale_path)
    val_rotation_distri = read_grnd_distri_from_txt(val_rotation_path)
    val_dataset = MPII('dataset/mpii-hr-lsp-normalizer.json',
                       '/bigdata1/zt53/data',
                       is_train=False,
                       grnd_scale_distri=val_scale_distri,
                       grnd_rotation_distri=val_rotation_distri)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=opt.bs,
                                             shuffle=False,
                                             num_workers=opt.nThreads,
                                             pin_memory=True)

    # NOTE(review): the agent's output bin counts come from the *validation*
    # dataset (this matches the original code, which reused one `dataset`
    # variable) — confirm both splits share the same scale/rotation bins.
    agent = model.create_asn(chan_in=256,
                             chan_out=256,
                             scale_num=len(val_dataset.scale_means),
                             rotation_num=len(val_dataset.rotation_means),
                             is_aug=True)
    agent = torch.nn.DataParallel(agent).cuda()
    optimizer = torch.optim.RMSprop(agent.parameters(),
                                    lr=opt.lr,
                                    alpha=0.99,
                                    eps=1e-8,
                                    momentum=0,
                                    weight_decay=0)
    if opt.load_prefix_sr == '':
        checkpoint_agent.save_prefix = sr_pretrain_dir + '/'
    else:
        # Resume the agent (weights, optimizer state, history) from a prior run.
        checkpoint_agent.save_prefix = sr_pretrain_dir + '/' + opt.load_prefix_sr
        checkpoint_agent.load_prefix = checkpoint_agent.save_prefix[0:-1]
        checkpoint_agent.load_checkpoint(agent, optimizer, train_history)
    print('agent: ', type(optimizer), optimizer.param_groups[0]['lr'])

    if opt.dataset == 'mpii':
        num_classes = 16
    else:
        # `num_classes` was previously left unbound for non-MPII datasets,
        # which crashed below with a NameError; fail fast with a clear error.
        raise ValueError('unsupported dataset: {}'.format(opt.dataset))
    hg = model.create_hg(num_stacks=2,
                         num_modules=1,
                         num_classes=num_classes,
                         chan=256)
    hg = torch.nn.DataParallel(hg).cuda()
    if opt.load_prefix_pose == '':
        print('please input the checkpoint name of the pose model')
        exit()
    checkpoint_hg = Checkpoint()
    checkpoint_hg.load_prefix = os.path.join(opt.exp_dir, opt.exp_id,
                                             opt.load_prefix_pose)[0:-1]
    checkpoint_hg.load_checkpoint(hg)

    logger = Logger(sr_pretrain_dir + '/' + 'training-summary.txt',
                    title='training-summary')
    logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss'])

    # Training and validation loop; resume epoch numbering when fine-tuning.
    start_epoch = 0
    if opt.load_prefix_sr != '':
        start_epoch = train_history.epoch[-1]['epoch'] + 1
    for epoch in range(start_epoch, opt.nEpochs):
        # Train for one epoch, then validate with the frozen pose model.
        train_loss = train(train_loader, hg, agent, optimizer, epoch,
                           visualizer, opt)
        val_loss = validate(val_loader, hg, agent, epoch, visualizer, opt)
        # Update the training history used for checkpointing and plotting.
        e = OrderedDict([('epoch', epoch)])
        lr = OrderedDict([('lr', optimizer.param_groups[0]['lr'])])
        loss = OrderedDict([('train_loss', train_loss),
                            ('val_loss', val_loss)])
        train_history.update(e, lr, loss)
        checkpoint_agent.save_checkpoint(agent,
                                         optimizer,
                                         train_history,
                                         is_asn=True)
        visualizer.plot_train_history(train_history, 'sr')
        logger.append(
            [epoch, optimizer.param_groups[0]['lr'], train_loss, val_loss])
    logger.close()
예제 #15
0
def main():
    """Train a DeepLabV3+ segmentation model and periodically validate it.

    Supports an optional reduced-dimension embedding head (``opts.reduce_dim``)
    with per-group learning rates, TensorBoard summaries, optional Visdom
    visualization, checkpoint restore/resume, and best-model checkpointing.

    All configuration comes from the argparse options; nothing is returned.
    """
    opts = get_argparser().parse_args()
    # Per-dataset class counts and the label value ignored by the loss.
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
        ignore_index = 255
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19
        ignore_index = 255
    elif opts.dataset.lower() == 'ade20k':
        opts.num_classes = 150
        ignore_index = -1
    elif opts.dataset.lower() == 'lvis':
        opts.num_classes = 1284
        ignore_index = -1
    elif opts.dataset.lower() == 'coco':
        opts.num_classes = 182
        ignore_index = 255
    else:
        # Previously an unknown dataset left `ignore_index` unbound and
        # crashed later with a NameError; fail fast with a clear error.
        raise ValueError('unsupported dataset: %s' % opts.dataset)
    if not opts.reduce_dim:
        opts.num_channels = opts.num_classes
    if not opts.test_only:
        writer = SummaryWriter('summary/' + opts.vis_env)
    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=2)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.val_batch_size,
                                 shuffle=False,
                                 num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))
    # Validate roughly once per epoch, but never less often than every 5000 its.
    epoch_interval = int(len(train_dst) / opts.batch_size)
    opts.val_interval = min(epoch_interval, 5000)
    print("Evaluation after %d iterations" % (opts.val_interval))

    # Set up model
    model_map = {
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }
    if opts.reduce_dim:
        num_classes_input = [opts.num_channels, opts.num_classes]
    else:
        num_classes_input = [opts.num_classes]
    model = model_map[opts.model](num_classes=num_classes_input,
                                  output_stride=opts.output_stride,
                                  reduce_dim=opts.reduce_dim)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer: the embedding weight gets its own parameter group so
    # it can use a different lr/momentum from the rest of the classifier.
    if opts.reduce_dim:
        emb_layer = ['embedding.weight']
        named_params = list(model.classifier.named_parameters())
        params_classifier = [p for name, p in named_params
                             if name not in emb_layer]
        params_embedding = [p for name, p in named_params if name in emb_layer]
        if opts.freeze_backbone:
            for param in model.backbone.parameters():
                param.requires_grad = False
            optimizer = torch.optim.SGD(
                params=[
                    {'params': params_classifier, 'lr': opts.lr},
                    {'params': params_embedding, 'lr': opts.lr,
                     'momentum': 0.95},
                ],
                lr=opts.lr,
                momentum=0.9,
                weight_decay=opts.weight_decay)
        else:
            optimizer = torch.optim.SGD(
                params=[
                    {'params': model.backbone.parameters(),
                     'lr': 0.1 * opts.lr},
                    {'params': params_classifier, 'lr': opts.lr},
                    {'params': params_embedding, 'lr': opts.lr},
                ],
                lr=opts.lr,
                momentum=0.9,
                weight_decay=opts.weight_decay)
    else:
        optimizer = torch.optim.SGD(
            params=[
                {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
                {'params': model.classifier.parameters(), 'lr': opts.lr},
            ],
            lr=opts.lr,
            momentum=0.9,
            weight_decay=opts.weight_decay)

    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=opts.step_size,
                                                    gamma=0.1)
    elif opts.lr_policy == 'multi_poly':
        scheduler = utils.MultiPolyLR(optimizer,
                                      opts.total_itrs,
                                      power=[0.9, 0.9, 0.95])
    else:
        # Previously an unknown policy left `scheduler` unbound (NameError).
        raise ValueError('unsupported lr_policy: %s' % opts.lr_policy)

    # Set up criterion: nearest-neighbour CE is required for the embedding head.
    if opts.reduce_dim:
        opts.loss_type = 'nn_cross_entropy'
    else:
        opts.loss_type = 'cross_entropy'

    if opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=ignore_index,
                                        reduction='mean')
    elif opts.loss_type == 'nn_cross_entropy':
        criterion = utils.NNCrossEntropy(ignore_index=ignore_index,
                                         reduction='mean',
                                         num_neighbours=opts.num_neighbours,
                                         temp=opts.temp,
                                         dataset=opts.dataset)

    def save_ckpt(path):
        """Save the current model/optimizer/scheduler state to `path`."""
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    utils.mkdir(opts.checkpoint_dir)
    # Restore from checkpoint if one is given; otherwise train from scratch.
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("scheduler state dict :", scheduler.state_dict())
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    vis_sample_id = np.random.randint(
        0, len(val_loader), opts.vis_num_samples,
        np.int32) if opts.enable_vis else None  # sample idxs for visualization
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224,
                                    0.225])  # denormalization for ori images

    if opts.test_only:
        model.eval()
        val_score, ret_samples = validate(opts=opts,
                                          model=model,
                                          loader=val_loader,
                                          device=device,
                                          metrics=metrics,
                                          ret_samples_ids=vis_sample_id)
        print(metrics.to_str(val_score))
        return

    interval_loss = 0

    # Log the run configuration to TensorBoard.
    writer.add_text('lr', str(opts.lr))
    writer.add_text('batch_size', str(opts.batch_size))
    writer.add_text('reduce_dim', str(opts.reduce_dim))
    writer.add_text('checkpoint_dir', opts.checkpoint_dir)
    writer.add_text('dataset', opts.dataset)
    writer.add_text('num_channels', str(opts.num_channels))
    writer.add_text('num_neighbours', str(opts.num_neighbours))
    writer.add_text('loss_type', opts.loss_type)
    writer.add_text('lr_policy', opts.lr_policy)
    writer.add_text('temp', str(opts.temp))
    writer.add_text('crop_size', str(opts.crop_size))
    writer.add_text('model', opts.model)
    accumulation_steps = 1
    writer.add_text('accumulation_steps', str(accumulation_steps))
    while True:
        # =====  Train  =====
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)
            # These datasets label background as 0; shift so it maps to the
            # ignore_index of -1.
            if (opts.dataset == 'ade20k' or opts.dataset == 'lvis'):
                labels = labels - 1

            optimizer.zero_grad()
            if opts.reduce_dim:
                outputs, class_emb = model(images)
                loss = criterion(outputs, labels, class_emb)
            else:
                outputs = model(images)
                loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            model.zero_grad()
            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss

            if vis is not None:
                vis.vis_scalar('Loss', cur_itrs, np_loss)
                vis.vis_scalar('LR', cur_itrs,
                               scheduler.state_dict()['_last_lr'][0])
            # Aggressively release per-iteration tensors to cap GPU memory.
            torch.cuda.empty_cache()
            del images, labels, outputs, loss
            if opts.reduce_dim:
                del class_emb
            gc.collect()
            if (cur_itrs) % 50 == 0:
                interval_loss = interval_loss / 50
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                writer.add_scalar('Loss', interval_loss, cur_itrs)
                writer.add_scalar('lr',
                                  scheduler.state_dict()['_last_lr'][0],
                                  cur_itrs)
            if cur_itrs % opts.val_interval == 0:
                save_ckpt(opts.checkpoint_dir + '/latest_%d.pth' % (cur_itrs))
                print("validation...")
                model.eval()
                val_score, ret_samples = validate(
                    opts=opts,
                    model=model,
                    loader=val_loader,
                    device=device,
                    metrics=metrics,
                    ret_samples_ids=vis_sample_id)
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt(opts.checkpoint_dir + '/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))

                writer.add_scalar('[Val] Overall Acc',
                                  val_score['Overall Acc'], cur_itrs)
                writer.add_scalar('[Val] Mean IoU', val_score['Mean IoU'],
                                  cur_itrs)
                writer.add_scalar('[Val] Mean Acc', val_score['Mean Acc'],
                                  cur_itrs)
                writer.add_scalar('[Val] Freq Acc', val_score['FreqW Acc'],
                                  cur_itrs)

                if vis is not None:  # visualize validation score and samples
                    vis.vis_scalar("[Val] Overall Acc", cur_itrs,
                                   val_score['Overall Acc'])
                    vis.vis_scalar("[Val] Mean IoU", cur_itrs,
                                   val_score['Mean IoU'])
                    vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

                    for k, (img, target, lbl) in enumerate(ret_samples):
                        img = (denorm(img) * 255).astype(np.uint8)
                        # Was `numpy.asarray` (NameError if only `np` is
                        # imported); use the `np` alias consistently.
                        if opts.dataset.lower() == 'coco':
                            target = np.asarray(
                                train_dst._colorize_mask(target).convert(
                                    'RGB')).transpose(2, 0, 1).astype(np.uint8)
                            lbl = np.asarray(
                                train_dst._colorize_mask(lbl).convert(
                                    'RGB')).transpose(2, 0, 1).astype(np.uint8)
                        else:
                            target = train_dst.decode_target(target).transpose(
                                2, 0, 1).astype(np.uint8)
                            lbl = train_dst.decode_target(lbl).transpose(
                                2, 0, 1).astype(np.uint8)
                        concat_img = np.concatenate(
                            (img, target, lbl), axis=2)  # concat along width
                        vis.vis_image('Sample %d' % k, concat_img)
                model.train()
            scheduler.step()
            if cur_itrs >= opts.total_itrs:
                # Close the writer before leaving; the old trailing
                # `writer.close()` after the loop was unreachable.
                writer.close()
                return
예제 #16
0
def main():
    """Load a trained pose-estimation model and visualize its predictions.

    When no image path is supplied (``opts.img == 'None'``), the first
    batch of the PENN_CROP validation split is pushed through the network,
    per-frame error and PCK accuracy are printed, and the predicted
    heatmaps are displayed.  Feeding a raw image from outside the dataset
    is only partially implemented.
    """
    # Parse the options from parameters
    opts = Opts().parse()
    # For PyTorch 0.4.1, cuda(device)
    opts.device = torch.device(f'cuda:{opts.gpu[0]}')
    print(opts.expID, opts.task, os.path.dirname(os.path.realpath(__file__)))
    # Load the trained model; bail out early when none is given.
    if opts.loadModel != 'none':
        model_path = os.path.join(opts.root_dir, opts.loadModel)
        model = torch.load(model_path).cuda(device=opts.device)
        model.eval()
    else:
        print('ERROR: No model is loaded!')
        return
    # Read the input image, pass input to gpu
    if opts.img == 'None':
        val_dataset = PENN_CROP(opts, 'val')
        val_loader = tud.DataLoader(val_dataset,
                                    batch_size=1,
                                    shuffle=False,
                                    num_workers=int(opts.num_workers))
        opts.nJoints = val_dataset.nJoints
        opts.skeleton = val_dataset.skeleton
        for i, gt in enumerate(val_loader):
            # Only the first batch is used for this visual sanity check.
            if i == 0:
                input, label = gt['input'], gt['label']
                gtpts, center, scale, proj = gt['gtpts'], gt['center'], gt[
                    'scale'], gt['proj']
                input_var = input[:, 0, ].float().cuda(device=opts.device,
                                                       non_blocking=True)
                output = model(input_var)
                # Accumulate error and accuracy (PCK) over the sequence.
                # (The unused Loss meter from the original was dropped.)
                Err, Acc = AverageMeter(), AverageMeter()
                ref = get_ref(opts.dataset, scale)
                for j in range(opts.preSeqLen):
                    pred = get_preds(output[:, j, ].cpu().float())
                    pred = original_coordinate(pred, center[:, ], scale,
                                               opts.outputRes)
                    err, ne = error(pred, gtpts[:, j, ], ref)
                    acc, na = accuracy(pred, gtpts[:, j, ], ref)
                    Err.update(err)
                    Acc.update(acc)
                    print(j, f"{Err.val:.6f}", Acc.val)
                print('all', f"{Err.avg:.6f}", Acc.avg)
                # Visualize the predicted heatmaps of the first frame.
                v = Visualizer(opts.nJoints, opts.skeleton, opts.outputRes)
                hm_img = output[0, 0, ].cpu().detach().numpy()
                v.add_hm(hm_img)
                # Show image (blocks until dismissed).
                v.show_img(pause=True)
                break
    else:
        print('NOT ready for the raw input outside the dataset')
        img = cv2.imread(opts.img)
        # BUG FIX: was img.tramspose(...) — AttributeError at runtime.
        input = torch.from_numpy(img.transpose(2, 0, 1)).float() / 256.
        input = input.view(1, input.size(0), input.size(1), input.size(2))
        # BUG FIX: torch.autograd.variable is a module; the wrapper class
        # is torch.autograd.Variable (kept for PyTorch 0.4.x compatibility).
        input_var = torch.autograd.Variable(input).float().cuda(
            device=opts.device)
        output = model(input_var)
        predJoints = get_preds(output[-2].data.cpu().numpy())[0] * 4
def main():
    """Evaluation entry point: build the generator, restore weights, test.

    Parses args from ``sys.argv[1]``, derives GPU ids from
    ``CUDA_VISIBLE_DEVICES``, creates a timestamped checkpoint dir and a
    logging visualizer, builds the test dataset/loader and generator,
    optionally resumes from a checkpoint, then calls ``test``.
    """
    # parse args
    global args
    args = parse_args(sys.argv[1])
    args.during_training = False

    # One logical id per visible device; the model itself is pinned to
    # cuda:0 (DataParallel-style scattering is presumably handled later).
    args.gpu_ids = list(
        range(len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))))
    args.device = torch.device('cuda:0')
    # Per-GPU test batch is a quarter of the training batch size.
    args.test_size = args.batch_size // 4 * len(args.gpu_ids)

    # add timestamp to ckpt_dir so repeated runs never collide
    args.timestamp = time.strftime('%m%d%H%M%S', time.localtime())
    args.ckpt_dir += '_' + args.timestamp

    # -------------------- init ckpt_dir, logging --------------------
    os.makedirs(args.ckpt_dir, mode=0o777, exist_ok=True)

    # -------------------- init visu --------------------
    visualizer = Visualizer(args)

    # Log the full command line and every parsed option for reproducibility.
    visualizer.logger.log('sys.argv:\n' + ' '.join(sys.argv))
    for arg in sorted(vars(args)):
        visualizer.logger.log('{:20s} {}'.format(arg, getattr(args, arg)))
    visualizer.logger.log('')

    # -------------------- dataset & loader --------------------
    # Dataset class is looked up by name from the project's datasets module.
    test_dataset = datasets.__dict__[args.dataset](
        train=False,
        transform=transforms.Compose([
            transforms.Resize(args.imageSize, Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]),
        args=args)

    visualizer.logger.log('test_dataset: ' + str(test_dataset))

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.test_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        # Re-seed numpy in each worker so augmentation randomness differs
        # per worker instead of being cloned from the parent process.
        worker_init_fn=lambda x: np.random.seed((torch.initial_seed()) %
                                                (2**32)))

    # -------------------- create model --------------------
    model_dict = {}

    # Generator input is the image channels plus the password bits.
    G_input_nc = args.input_nc + args.passwd_length
    model_dict['G'] = models.define_G(G_input_nc,
                                      args.output_nc,
                                      args.ngf,
                                      args.which_model_netG,
                                      args.n_downsample_G,
                                      args.normG,
                                      args.dropout,
                                      args.init_type,
                                      args.init_gain,
                                      args.passwd_length,
                                      use_leaky=args.use_leakyG,
                                      use_resize_conv=args.use_resize_conv,
                                      padding_type=args.padding_type)
    model_dict['G_nets'] = [model_dict['G']]

    print('model_dict')
    for k, v in model_dict.items():
        print(k + ':')
        if isinstance(v, list):
            print('list, len:', len(v))
            print('')
        else:
            print(v)

    # -------------------- resume --------------------
    if args.resume:
        if osp.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # Load to CPU first to avoid GPU OOM during deserialization.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1

            name = 'G'
            net = model_dict[name]
            # Unwrap DataParallel so state-dict keys match.
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            net.load_state_dict(checkpoint['state_dict_' + name])

            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
        # Drop the checkpoint dict and release cached GPU memory.
        gc.collect()
        torch.cuda.empty_cache()

    test(test_loader, model_dict, visualizer, args)
예제 #18
0
# Top-level setup: echo options, build loaders, visualizer and model.
show_jot_opt(args)

# dataset
test_loader = data.DataLoader(create_dataset(args, 'test'),
                              args.batch_size_test,
                              num_workers=args.num_workers,
                              shuffle=False,
                              pin_memory=args.use_cuda)
train_loader = data.DataLoader(create_dataset(args, 'train'),
                               args.batch_size_train,
                               num_workers=args.num_workers,
                               shuffle=True,
                               pin_memory=args.use_cuda)

# init visualizer
visual = Visualizer(args)

# model (class count comes from the training dataset)
model = EncapNet(opts=args, num_classes=train_loader.dataset.num_classes)
# TODO (low): resume if program stops

if args.use_cuda:
    if len(args.device_id) == 1:
        # pt_new presumably selects the post-0.4 `.to(device)` API over
        # `.cuda()` — TODO confirm against args definition.
        model = model.cuda() if not args.pt_new else model.to(args.device)
        print_log('single gpu mode', args.file_name)
    else:
        model = torch.nn.DataParallel(model).cuda() \
            if not args.pt_new else torch.nn.DataParallel(model.to(args.device))
        print_log('multi-gpu mode', args.file_name)
else:
    # CPU execution is deliberately unsupported.
    raise NotImplementedError('we do not like cpu mode ...')
예제 #19
0
def main():
    """Train or evaluate a binarized CU-Net for MPII human-pose estimation.

    Builds the network, optimizer and MPII data loaders, optionally
    resumes from a checkpoint, then either runs validation only
    (``not opt.is_train``) or the epoch loop of training + validation,
    appending loss/PCKh metrics to a summary logger.
    """
    opt = TrainOptions().parse()
    train_history = TrainHistory()
    checkpoint = Checkpoint()
    visualizer = Visualizer(opt)
    exp_dir = os.path.join(opt.exp_dir, opt.exp_id)
    log_name = opt.vis_env + 'log.txt'
    visualizer.log_name = os.path.join(exp_dir, log_name)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
    # MPII annotates 16 joints per person.
    num_classes = 16
    net = create_cu_net(neck_size=4,
                        growth_rate=32,
                        init_chan_num=128,
                        class_num=num_classes,
                        layer_num=opt.layer_num,
                        order=1,
                        loss_num=opt.layer_num)
    net = torch.nn.DataParallel(net).cuda()
    # The binarization operator is consumed by train()/validate() through
    # this module-level global.
    global bin_op
    bin_op = BinOp(net)
    optimizer = torch.optim.RMSprop(net.parameters(),
                                    lr=opt.lr,
                                    alpha=0.99,
                                    eps=1e-8,
                                    momentum=0,
                                    weight_decay=0)
    # Optionally resume from a checkpoint and continue with its LR.
    if opt.resume_prefix != '':
        checkpoint.save_prefix = exp_dir + '/'
        checkpoint.load_prefix = os.path.join(exp_dir, opt.resume_prefix)[0:-1]
        checkpoint.load_checkpoint(net, optimizer, train_history)
        opt.lr = optimizer.param_groups[0]['lr']
        resume_log = True
    else:
        checkpoint.save_prefix = exp_dir + '/'
        resume_log = False
    # BUG FIX: Python-2 `print` statements were syntax errors under
    # Python 3 (the surrounding code uses py3-only features).
    print('save prefix: ', checkpoint.save_prefix)

    # Load data.
    train_loader = torch.utils.data.DataLoader(MPII(
        'dataset/mpii-hr-lsp-normalizer.json',
        '/bigdata1/zt53/data',
        is_train=True),
                                               batch_size=opt.bs,
                                               shuffle=True,
                                               num_workers=opt.nThreads,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(MPII(
        'dataset/mpii-hr-lsp-normalizer.json',
        '/bigdata1/zt53/data',
        is_train=False),
                                             batch_size=opt.bs,
                                             shuffle=False,
                                             num_workers=opt.nThreads,
                                             pin_memory=True)
    print(type(optimizer))
    # Subset of the 16 joints used for evaluation (excludes 6-9, 12, 13).
    idx = [0, 1, 2, 3, 4, 5, 10, 11, 14, 15]
    logger = Logger(os.path.join(opt.exp_dir, opt.exp_id,
                                 'training-summary.txt'),
                    title='training-summary',
                    resume=resume_log)
    logger.set_names(
        ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])
    # Validation-only mode: evaluate the resumed model and save predictions.
    if not opt.is_train:
        visualizer.log_path = os.path.join(opt.exp_dir, opt.exp_id,
                                           'val_log.txt')
        val_loss, val_pckh, predictions = validate(
            val_loader, net, train_history.epoch[-1]['epoch'], visualizer, idx,
            joint_flip_index, num_classes)
        checkpoint.save_preds(predictions)
        return
    # Training and validation loop; resume from the epoch after the
    # checkpointed one when applicable.
    start_epoch = 0
    if opt.resume_prefix != '':
        start_epoch = train_history.epoch[-1]['epoch'] + 1
    for epoch in range(start_epoch, opt.nEpochs):
        adjust_lr(opt, optimizer, epoch)
        # Train for one epoch.
        train_loss, train_pckh = train(train_loader, net, optimizer, epoch,
                                       visualizer, idx, opt)

        # Evaluate on the validation set.
        val_loss, val_pckh, predictions = validate(val_loader, net, epoch,
                                                   visualizer, idx,
                                                   joint_flip_index,
                                                   num_classes)
        # Update training history and persist a checkpoint.
        e = OrderedDict([('epoch', epoch)])
        lr = OrderedDict([('lr', optimizer.param_groups[0]['lr'])])
        loss = OrderedDict([('train_loss', train_loss),
                            ('val_loss', val_loss)])
        pckh = OrderedDict([('val_pckh', val_pckh)])
        train_history.update(e, lr, loss, pckh)
        checkpoint.save_checkpoint(net, optimizer, train_history, predictions)
        logger.append([
            epoch, optimizer.param_groups[0]['lr'], train_loss, val_loss,
            train_pckh, val_pckh
        ])
    logger.close()
예제 #20
0
    # beta
    train_opt.lambda_E = 0.01
    # gamma_3
    train_opt.lambda_F = 10000

    train_dataloader = get_dataloader(train_opt, isTrain=True)
    dataset_size = len(train_dataloader)
    train_model = create_model(train_opt, train_dataloader.hsi_channels,
                               train_dataloader.msi_channels,
                               train_dataloader.lrhsi_height,
                               train_dataloader.lrhsi_width,
                               train_dataloader.sp_matrix,
                               train_dataloader.sp_range)

    train_model.setup(train_opt)
    visualizer = Visualizer(train_opt, train_dataloader.sp_matrix)

    total_steps = 0

    for epoch in range(train_opt.epoch_count,
                       train_opt.niter + train_opt.niter_decay + 1):

        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0

        train_psnr_list = []

        for i, data in enumerate(train_dataloader):

            iter_start_time = time.time()
예제 #21
0
        best_iou = np.loadtxt(ioupath_path, dtype=float)
    except:
        best_iou = 0
    print('Resuming from epoch %d at iteration %d, previous best IoU: %f' %
          (start_epoch, epoch_iter, best_iou))
else:
    start_epoch, epoch_iter = 1, 0
    best_iou = 0.

data_loader = CreateDataLoader(opt)
dataset, dataset_val = data_loader.load_data()
dataset_size = len(dataset)
print('#training images = %d' % dataset_size)

model = create_model(opt, dataset.dataset)
visualizer = Visualizer(opt)
total_steps = (start_epoch - 1) * dataset_size + epoch_iter
pre_compute_flag = 0

print("Precompute weight for 5 epoches")
for pretrain_epoch in range(5):
    print('epoch: %d' % (pretrain_epoch))
    model.model.train()
    for i, data in enumerate(dataset, start=epoch_iter):
        model.pre_compute_W(i, data)

pre_compute_flag = 1

for epoch in range(start_epoch, opt.nepochs):
    epoch_start_time = time.time()
    if epoch != start_epoch:
예제 #22
0
import time

from options.options import Options
from models import audio_expression_model
from datasets import create_dataset
from utils.visualizer import Visualizer


if __name__ == '__main__':
    # Training driver for the audio-to-expression model.
    opt = Options().parse_args()   # get training options

    dataset = create_dataset(opt)

    model = audio_expression_model.AudioExpressionModel(opt)

    visualizer = Visualizer(opt)   # create a visualizer that display/save images and plots

    total_iters = 0                # global iteration counter across epochs

    for epoch in range(opt.num_epoch):

        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch

        for i, data in enumerate(dataset):  # inner loop within one epoch

            iter_start_time = time.time()  # timer for computation per iteration
            # Measure data-loading latency only on logging iterations.
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
def main():
    """Iteration-based training driver for the segmentation models.

    Parses options, prepares loaders, builds the selected model, SGD
    optimizer, LR scheduler and loss, optionally restores a checkpoint,
    then trains until ``opts.total_itrs`` iterations, validating and
    checkpointing every ``opts.val_interval`` iterations.
    """

    opts = get_argparser().parse_args()

    # NOTE(review): string concatenation inside a single os.path.join —
    # the result is simply opts.save_dir + opts.model + '/'. Confirm
    # save_dir is expected to already end with a separator.
    save_dir = os.path.join(opts.save_dir + opts.model + '/')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    print('Save position is %s\n' % (save_dir))

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    # select the GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s,  CUDA_VISIBLE_DEVICES: %s\n" % (device, opts.gpu_id))

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=8,
                                   drop_last=True,
                                   pin_memory=False)
    # NOTE(review): the validation loader shuffles and drops the last
    # incomplete batch, so metrics are computed on a reshuffled, truncated
    # subset each run — confirm this is intentional.
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.batch_size,
                                 shuffle=True,
                                 num_workers=8,
                                 drop_last=True,
                                 pin_memory=False)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model: architecture name -> constructor lookup.
    model_map = {
        'self_contrast': network.self_contrast,
        'DCNet_L1': network.DCNet_L1,
        'DCNet_L12': network.DCNet_L12,
        'DCNet_L123': network.DCNet_L123,
        'FCN': network.FCN,
        'UNet': network.UNet,
        'SegNet': network.SegNet,
        'cloudSegNet': network.cloudSegNet,
        'cloudUNet': network.cloudUNet
    }

    print('Model = %s, num_classes=%d' % (opts.model, opts.num_classes))
    model = model_map[opts.model](n_classes=opts.num_classes,
                                  is_batchnorm=True,
                                  in_channels=opts.in_channels,
                                  feature_scale=opts.feature_scale,
                                  is_deconv=False)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=opts.lr,
                                momentum=0.9,
                                weight_decay=opts.weight_decay)

    # NOTE(review): scheduler is only bound for 'poly' and 'step' policies;
    # any other lr_policy value raises NameError at scheduler.step() below.
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=opts.step_size,
                                                    gamma=0.5)

    # Set up criterion — same caveat: only the two listed loss types bind it.
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss()
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss()

    def save_ckpt(path):
        """ save current model
        """
        # Saves the unwrapped module's weights plus the optimizer,
        # scheduler and bookkeeping needed to resume training.
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s\n\n" % path)

    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        # Only load weights whose keys exist in the current architecture,
        # so partially-matching checkpoints still restore.
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in checkpoint["model_state"].items() if (k in model_dict)
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # Extend the restored scheduler to the (possibly new) horizon.
            scheduler.max_iters = opts.total_itrs
            scheduler.min_lr = opts.lr
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Continue training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        print("Best_score is %s" % (str(best_score)))
        # del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    # ==========   Train Loop   ==========#
    # Fixed validation-sample indices for visualization across epochs.
    vis_sample_id = np.random.randint(
        0, len(val_loader), opts.vis_num_samples,
        np.int32) if opts.enable_vis else None  # sample idxs for visualization

    interval_loss = 0
    train_loss = list()
    train_accuracy = list()
    best_val_itrs = list()
    while True:  # cur_itrs < opts.total_itrs:
        # =====  Train  =====
        model.train()
        cur_epochs += 1

        for (images, labels) in train_loader:
            # Start the interval timer on logging boundaries.
            if (cur_itrs) == 0 or (cur_itrs) % opts.print_interval == 0:
                t1 = time.time()

            cur_itrs += 1

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss

            # Report mean loss over the last print interval.
            if (cur_itrs) % opts.print_interval == 0:
                interval_loss = interval_loss / opts.print_interval
                train_loss.append(interval_loss)
                t2 = time.time()
                print("Epoch %d, Itrs %d/%d, Loss=%f, Time = %f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss,
                       t2 - t1))
                interval_loss = 0.0

            # save the ckpt file per 5000 itrs
            if (cur_itrs) % opts.val_interval == 0:
                print("validation...")
                model.eval()

                save_ckpt(save_dir + 'latest_%s_%s_itrs%s.pth' %
                          (opts.model, opts.dataset, str(cur_itrs)))
                time_before_val = time.time()
                val_score, ret_samples = validate(
                    opts=opts,
                    model=model,
                    loader=val_loader,
                    device=device,
                    metrics=metrics,
                    ret_samples_ids=vis_sample_id)

                time_after_val = time.time()
                print('Time_val = %f' % (time_after_val - time_before_val))
                print(metrics.to_str(val_score))

                # Track best overall accuracy and snapshot the best model.
                train_accuracy.append(val_score['overall_acc'])
                if val_score['overall_acc'] > best_score:  # save best model
                    best_score = val_score['overall_acc']
                    save_ckpt(save_dir + 'best_%s_%s_.pth' %
                              (opts.model, opts.dataset))
                    best_val_itrs.append(cur_itrs)
                model.train()
            scheduler.step()  # update

            if cur_itrs >= opts.total_itrs:
                print(cur_itrs)
                print(opts.total_itrs)
                return
예제 #24
0
from train import train
from val import val

if __name__ == "__main__":

    opt = Opt().parse()

    ########################################
    #                 Model                #
    ########################################
    torch.manual_seed(opt.manual_seed)

    if opt.no_vis:
        visualizer = None
    else:
        visualizer = Visualizer(opt)
    model = get_model(opt)
    if opt.optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    elif opt.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr)
    else:
        NotImplementedError("Only Adam and SGD are supported")

    ########################################
    #              Transforms              #
    ########################################
    if not opt.no_train:
        train_transforms = get_train_transforms(opt)
        dataset = get_train_dataset(opt, train_transforms)
        dataloader = torch.utils.data.DataLoader(dataset,
예제 #25
0
def main():
    """Train a face-recognition backbone with a margin-based metric head.

    Builds the dataset, backbone, metric layer, loss, optimizer and LR
    scheduler from the static ``Config``, then trains for
    ``opt.max_epoch`` epochs, validating after each one.
    """
    log_dir = os.path.join('logs', '000')
    # BUG FIX: validate() receives log_dir but the directory was never
    # created, so logging there would fail on a fresh checkout.
    os.makedirs(log_dir, exist_ok=True)
    opt = Config()
    if opt.display:
        # NOTE(review): visualizer is created for its side effects only;
        # it is never referenced again in this function — confirm.
        visualizer = Visualizer()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(1)

    train_dataset = Dataset(opt.train_root,
                            opt.train_list,
                            phase='train',
                            input_shape=opt.input_shape)
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=opt.train_batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)

    # Loss: focal loss or plain cross-entropy.
    if opt.loss == 'focal_loss':
        criterion = FocalLoss(gamma=2)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # Backbone selection.
    if opt.backbone == 'resnet18':
        model = resnet_face18(use_se=opt.use_se)
    elif opt.backbone == 'resnet34':
        model = resnet_face34(use_se=opt.use_se)
    elif opt.backbone == 'resnet50':
        model = resnet50()
    else:
        # BUG FIX: an unknown backbone previously fell through and crashed
        # later with NameError; fail fast with a clear message instead.
        raise ValueError('unsupported backbone: %s' % opt.backbone)

    # Metric head: margin-based products or a plain linear classifier.
    if opt.metric == 'add_margin':
        metric_fc = AddMarginProduct(512,
                                     opt.num_classes,
                                     s=30,
                                     m=0.35,
                                     device=device)
    elif opt.metric == 'arc_margin':
        metric_fc = ArcMarginProduct(512,
                                     opt.num_classes,
                                     s=30,
                                     m=0.5,
                                     easy_margin=opt.easy_margin,
                                     device=device)
    elif opt.metric == 'sphere':
        metric_fc = SphereProduct(512, opt.num_classes, m=4, device=device)
    else:
        metric_fc = torch.nn.Linear(512, opt.num_classes)

    model.to(device)
    summary(model, input_size=opt.input_shape)
    model = DataParallel(model)
    metric_fc.to(device)
    metric_fc = DataParallel(metric_fc)

    print('{} train iters per epoch:'.format(len(trainloader)))

    # Optimize backbone and metric head jointly.
    if opt.optimizer == 'sgd':
        optimizer = torch.optim.SGD([{
            'params': model.parameters()
        }, {
            'params': metric_fc.parameters()
        }],
                                    lr=opt.lr,
                                    weight_decay=opt.weight_decay)
    else:
        optimizer = torch.optim.Adam([{
            'params': model.parameters()
        }, {
            'params': metric_fc.parameters()
        }],
                                     lr=opt.lr,
                                     weight_decay=opt.weight_decay)

    scheduler = StepLR(optimizer, step_size=opt.lr_step, gamma=0.1)

    for epoch in range(opt.max_epoch):
        print('Epoch %d/%d' % (epoch, opt.max_epoch))
        train(opt, model, metric_fc, device, trainloader, criterion, optimizer,
              scheduler)
        validate(opt, model, device, epoch, log_dir)
        # BUG FIX: step the scheduler AFTER the epoch's optimizer updates
        # (required ordering since PyTorch 1.1); stepping first effectively
        # skipped the initial learning rate.
        scheduler.step()
예제 #26
0
            return ds_id - 1

    else:
        raise ValueError("Unsupported dataset: {}".format(args.dataset))

    os.makedirs(args.output, exist_ok=True)

    for dic in tqdm.tqdm(dicts):
        flag = False
        for ann in dic['annotations']:
            category = metadata.get('thing_classes')[ann['category_id']]

        img = cv2.imread(dic["file_name"], cv2.IMREAD_COLOR)[:, :, ::-1]
        basename = os.path.basename(dic["file_name"])

        vis = Visualizer(img, metadata)
        vis_pred = vis.draw_proposals_separately(
            proposal_by_image[dic["image_id"]], img.shape[:2],
            args.conf_threshold)

        predictions = create_instances(pred_by_image[dic["image_id"]],
                                       img.shape[:2])
        vis = Visualizer(img, metadata)
        pred = vis.draw_instance_predictions(predictions).get_image()
        vis_pred.append(pred)
        vis = Visualizer(img, metadata)
        gt = vis.draw_dataset_dict(dic).get_image()

        concat = vis.smart_concatenate(vis_pred, min_side=1960)
        vis = np.concatenate([pred, gt], axis=1)
예제 #27
0
class Treainer(object):
    """GAN-based super-resolution trainer (SRFeat/CARN-family generators).

    Trains a generator against two discriminators — one on raw images
    (``netD``) and one on VGG feature maps (``netD_vgg``) — with a frozen
    VGG network providing perceptual features. Supports optional multi-GPU
    training via the ``init_distributed`` / ``apply_gradient_allreduce``
    helpers. Progress, losses, and sample images are reported through a
    project ``Visualizer``.

    NOTE(review): the class name looks like a typo for ``Trainer``; kept
    as-is because external code may reference it.
    """

    # NOTE(review): `dis_list=[]` is a mutable default argument — harmless
    # here because it is only read, but worth fixing to `None` upstream.
    def __init__(self,
                 opt=None,
                 train_dt=None,
                 train_dt_warm=None,
                 dis_list=[],
                 val_dt_warm=None):
        # Device selection: single `device` object used for every module/tensor.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.opt = opt

        self.visualizer = Visualizer(opt)

        num_gpus = torch.cuda.device_count()
        #dis_list[1]
        print(dis_list)
        #torch.cuda.device_count()
        # dis_list layout appears to be [rank, num_gpus, group_name,
        # dist_config] — TODO confirm against the caller.
        self.rank = dis_list[0]
        print(self.rank)

        #=====START: ADDED FOR DISTRIBUTED======
        if num_gpus > 1:
            #init_distributed(rank, num_gpus, group_name, **dist_config)
            dist_config = dis_list[3]
            init_distributed(dis_list[0], dis_list[1], dis_list[2],
                             **dist_config)
        #=====END:   ADDED FOR DISTRIBUTED======

        # Generator architecture selected by option string.
        if opt.ge_net == "srfeat":
            self.netG = model.G()
        elif opt.ge_net == "carn":
            self.netG = model.G1()
        elif opt.ge_net == "carnm":
            self.netG = model.G2()
        else:
            raise Exception("unknow ")

        # Feature-space discriminator: consumes VGG feature maps
        # (512 channels, 18x18 spatial — presumably relu5-ish; verify).
        self.netD_vgg = model.D(input_c=512, input_width=18)

        # Image-space discriminator.
        self.netD = model.D()

        # VGG used only as a frozen feature extractor for perceptual loss.
        if opt.vgg_type == "style":
            self.vgg = load_vgg16(opt.vgg_model_path + '/models')
        elif opt.vgg_type == "classify":
            self.vgg = model.vgg19_withoutbn_customefinetune()
        # NOTE(review): no else-branch — an unknown `vgg_type` leaves
        # `self.vgg` undefined and fails on the next line.

        self.vgg.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False

#         for p in self.vgg.parameters():
#             p.requires_grad = False

        init_weights(self.netD, init_type=opt.init)
        init_weights(self.netD_vgg, init_type=opt.init)
        init_weights(self.netG, init_type=opt.init)

        self.vgg = self.vgg.to(self.device)
        self.netD = self.netD.to(self.device)
        self.netD_vgg = self.netD_vgg.to(self.device)
        self.netG = self.netG.to(self.device)

        #=====START: ADDED FOR DISTRIBUTED======
        if num_gpus > 1:
            #self.vgg = apply_gradient_allreduce(self.vgg)
            self.netD_vgg = apply_gradient_allreduce(self.netD_vgg)
            self.netD = apply_gradient_allreduce(self.netD)
            self.netG = apply_gradient_allreduce(self.netG)

        #=====END:   ADDED FOR DISTRIBUTED======

        print(opt)

        # Generator optimizer uses the warm-up hyperparameters
        # (opt.warm_opt.*) rather than opt.gen.* — see commented block below.
        self.optim_G= torch. optim.Adam(filter(lambda p: p.requires_grad, self.netG.parameters()),\
         lr=opt.warm_opt.lr, betas=opt.warm_opt.betas, weight_decay=0.0)

        #        self.optim_G= torch.optim.Adam(filter(lambda p: p.requires_grad, self.netG.parameters()),\
        #         lr=opt.gen.lr, betas=opt.gen.betas, weight_decay=0.0)

        # Both discriminators share one optimizer (chained parameters).
        if opt.dis.optim == "sgd":
            self.optim_D= torch.optim.SGD( filter(lambda p: p.requires_grad, \
                itertools.chain(self.netD_vgg.parameters(),self.netD.parameters() ) ),\
                lr=opt.dis.lr,
             )
        elif opt.dis.optim == "adam":
            self.optim_D= torch.optim.Adam( filter(lambda p: p.requires_grad, \
                itertools.chain(self.netD_vgg.parameters(),self.netD.parameters() ) ),\
                lr=opt.dis.lr,betas=opt.dis.betas, weight_decay=0.0
             )
        else:
            raise Exception("unknown")

        print("create schedule ")

        lr_sc_G = get_scheduler(self.optim_G, opt.gen)
        lr_sc_D = get_scheduler(self.optim_D, opt.dis)

        self.schedulers = []

        self.schedulers.append(lr_sc_G)
        self.schedulers.append(lr_sc_D)

        # =====START: ADDED FOR DISTRIBUTED======
        # Main and warm-up training sets are merged into one loader.
        train_dt = torch.utils.data.ConcatDataset([train_dt, train_dt_warm])

        train_sampler = DistributedSampler(train_dt) if num_gpus > 1 else None
        val_sampler_warm = DistributedSampler(
            val_dt_warm) if num_gpus > 1 else None
        # =====END:   ADDED FOR DISTRIBUTED======

        kw = {
            "pin_memory": True,
            "num_workers": 8
        } if torch.cuda.is_available() else {}
        dl_c =t_data.DataLoader(train_dt ,batch_size=opt.batch_size,\
             sampler=train_sampler , drop_last=True, **kw )

        # Validation loader uses opt.batch_size_warm when present.
        dl_val_warm = t_data.DataLoader(
            val_dt_warm,
            batch_size=opt.batch_size
            if not hasattr(opt, "batch_size_warm") else opt.batch_size_warm,
            sampler=val_sampler_warm,
            drop_last=True,
            **kw)

        self.dt_train = dl_c
        self.dt_val_warm = dl_val_warm

        # Pixel-wise reconstruction criterion selected by option.
        if opt.warm_opt.loss_fn == "mse":
            self.critic_pixel = torch.nn.MSELoss()
        elif opt.warm_opt.loss_fn == "l1":
            self.critic_pixel = torch.nn.L1Loss()
        elif opt.warm_opt.loss_fn == "smooth_l1":
            self.critic_pixel = torch.nn.SmoothL1Loss()
        else:
            raise Exception("unknown")

        self.critic_pixel = self.critic_pixel.to(self.device)

        self.gan_loss = GANLoss(gan_mode=opt.gan_loss_fn).to(self.device)
        print("init ....")

        # Checkpoints are saved next to the visualizer's log file.
        self.save_dir = os.path.dirname(self.visualizer.log_name)

    def _validate_(self):
        """Run one validation pass over the warm-up validation loader.

        Computes pixel loss, MS-SSIM, and PSNR for both the generator output
        and the bicubic-upsampled baseline, logs/plots them via the
        visualizer, and displays a random sample grid of
        (bicubic | output | ground-truth) triplets.

        Returns:
            tuple: (loss, ssim, psnr, cub_loss, cub_ssim, cub_psnr) means.
        """
        with torch.no_grad():
            print("val ," * 8, "warm start...", len(self.dt_val_warm))
            iter_start_time = time.time()
            ssim = []
            batch_loss = []
            psnr = []

            # "cub_*" metrics score the bicubic baseline for comparison.
            cub_ssim = []
            cub_batch_loss = []
            cub_psnr = []

            save_image_list_1 = []

            for ii, data in tqdm.tqdm(enumerate(self.dt_val_warm)):
                # Dataset may yield extra fields (e.g. paths); only the
                # first three tensors are used here.
                if len(data) > 3:
                    input_lr, input_hr, cubic_hr, _, _ = data
                else:
                    input_lr, input_hr, cubic_hr = data

                self.input_lr = input_lr.to(self.device)
                self.input_hr = input_hr.to(self.device)
                self.input_cubic_hr = cubic_hr.to(self.device)

                self.forward()

                # Stack bicubic | generated | GT side by side (width dim).
                save_image_list_1.append(torch.cat( [self.input_cubic_hr ,\
                 self.output_hr ,\
                 self.input_hr ],dim=3)  )

                loss = self.critic_pixel(self.output_hr, self.input_hr)
                batch_loss.append(loss.item())
                ssim.append(
                    image_quality.msssim(self.output_hr, self.input_hr).item())
                psnr.append(
                    image_quality.psnr(self.output_hr, self.input_hr).item())

                cub_loss = self.critic_pixel(self.input_cubic_hr,
                                             self.input_hr)
                cub_batch_loss.append(cub_loss.item())
                cub_ssim.append(
                    image_quality.msssim(self.input_cubic_hr,
                                         self.input_hr).item())
                cub_psnr.append(
                    image_quality.psnr(self.input_cubic_hr,
                                       self.input_hr).item())

            # Show 8 random comparison strips, concatenated along height.
            np.random.shuffle(save_image_list_1)
            save_image_list = save_image_list_1[:8]
            save_image_list = util.tensor2im(torch.cat(save_image_list, dim=2))
            save_image_list = OrderedDict([("cub_out_gt", save_image_list)])
            self.visualizer.display_current_results(save_image_list,
                                                    self.epoch,
                                                    save_result=True,
                                                    offset=20,
                                                    title="val_imag")

            val_info = (np.mean(batch_loss), np.mean(ssim), np.mean(psnr),
                        np.mean(cub_batch_loss), np.mean(cub_ssim),
                        np.mean(cub_psnr))
            errors = dict(
                zip(("loss", "ssim", "psnr", "cub_loss", "cub_ssim",
                     "cub_psnr"), val_info))
            t = (time.time() - iter_start_time)
            self.visualizer.print_current_errors(self.epoch,
                                                 self.epoch,
                                                 errors,
                                                 t,
                                                 log_name="loss_log_val.txt")
            self.visualizer.plot_current_errors(self.epoch,
                                                self.epoch,
                                                opt=None,
                                                errors=errors,
                                                display_id_offset=3,
                                                loss_name="val")

            return val_info

    def run(self):
        """Entry point: resume from the latest checkpoint, then train.

        NOTE(review): `current_epoch` is loaded but not passed to
        `_run_train`, which always restarts at epoch 0 — confirm intended.
        """
        current_epoch = self.load_networks()
        self._run_train()

    def _run_train(self):
        """Main adversarial training loop over warm + regular epochs.

        Each epoch: validate first (saving "best" and per-epoch
        checkpoints), then alternate generator and discriminator updates
        per batch, with periodic visualization/logging, and finally step
        the LR schedulers.
        """
        print("train.i..." * 8)
        total_steps = 0
        opt = self.opt

        # Names used by save_networks/load_networks ('net' + name attrs).
        self.model_names = ["G", "D", "D_vgg"]

        # Placeholder so get_current_errors works before the first g_loss().
        self.loss_w_g = torch.tensor(0)
        dataset_size = len(self.dt_train) * opt.batch_size
        best_loss = 10e5

        for epoch in range(0, self.opt.epoches_warm + self.opt.epoches):
            self.epoch = epoch
            #             epoch_start_time = time.time()
            epoch_iter = 0

            # Validation (and checkpointing) happens BEFORE training each
            # epoch, so epoch 0 is validated with untrained weights.
            val_loss = self._validate_()
            val_loss = val_loss[0]
            if best_loss > val_loss:
                best_loss = val_loss
                self.save_networks("best")
            self.save_networks(epoch)

            for data in self.dt_train:
                if len(data) > 3:
                    input_lr, input_hr, cubic_hr, _, _ = data
                else:
                    input_lr, input_hr, cubic_hr = data

                iter_start_time = time.time()

                self.input_lr = input_lr.to(self.device)
                self.input_hr = input_hr.to(self.device)
                # NOTE(review): unlike validation, cubic_hr is NOT moved to
                # the device here — only used for visualization, verify.
                self.input_cubic_hr = cubic_hr

                self.forward()

                # Generator step (g_loss() calls backward internally).
                self.optim_G.zero_grad()
                self.g_loss()
                self.optim_G.step()

                # Discriminator step (d_loss() calls backward internally).
                self.optim_D.zero_grad()
                self.d_loss()
                self.optim_D.step()

                self.visualizer.reset()
                total_steps += opt.batch_size
                epoch_iter += opt.batch_size

                if total_steps % opt.display_freq == 0:
                    save_result = total_steps % opt.update_html_freq == 0
                    self.visualizer.display_current_results(
                        self.get_current_visuals(), epoch, save_result)

                if total_steps % opt.print_freq == 0:
                    errors = self.get_current_errors()
                    t = (time.time() - iter_start_time) / opt.batch_size
                    self.visualizer.print_current_errors(
                        epoch, epoch_iter, errors, t)
                    if opt.display_id > 0:
                        self.visualizer.plot_current_errors(
                            epoch,
                            float(epoch_iter) / dataset_size, opt, errors)

                # NOTE(review): `continue` as the last statement of the loop
                # body has no effect; the non-zero-rank guard is dead code.
                if self.rank != 0:
                    continue
            lr_g, lr_d = self.update_learning_rate(is_warm=False)
            self.visualizer.plot_current_lrs(epoch,0,opt=None,\
                errors=OrderedDict([ ('lr_warm_g',0),("lr_g",lr_g),("lr_d",lr_d) ]) ,  loss_name="lr_warm"  ,display_id_offset=1)

    def forward(self, ):
        """Run the generator on `self.input_lr`, storing `self.output_hr`."""
        self.output_hr = self.netG(self.input_lr)
        #         self.input_hr
        pass

    def g_loss(self, ):
        """Compute and backpropagate the generator loss.

        Combines: image-space adversarial loss, VGG-feature adversarial
        loss, perceptual (feature reconstruction) loss, and pixel loss,
        then calls backward (optimizer step is the caller's job).
        """
        #print (self.opt.gen,type(self.opt.gen),self.opt.gen.keys())
        # Input scaling factor applied before the VGG feature extractor.
        vgg_r = self.opt.gen.lambda_vgg_input
        #g feature f
        x_f_fake = self.vgg(vgg_r * self.output_hr)

        #g .. f
        d_fake = self.netD(self.output_hr)
        self.loss_G_g = self.opt.gen.lambda_vgg_loss * self.gan_loss(
            d_fake, True)

        fd_fake = self.netD_vgg(x_f_fake)
        self.loss_G_fg = self.opt.gen.lambda_vgg_loss * self.gan_loss(
            fd_fake, True)

        ## perception
        x_f_real = self.vgg(vgg_r * self.input_hr)
        self.loss_G_p = self.critic_pixel(x_f_fake, x_f_real)

        # Pixel-space reconstruction term (weight from warm-up options).
        self.loss_w_g = self.opt.warm_opt.lambda_warm_loss * self.critic_pixel(
            self.output_hr, self.input_hr)

        self.loss_g =  self.loss_G_g + self.loss_G_fg  + self.loss_G_p +\
             self.loss_w_g

        self.loss_g.backward()

        # NOTE(review): clip_grad_norm is the deprecated spelling of
        # clip_grad_norm_ in modern PyTorch — verify the installed version.
        if hasattr(self.opt.warm_opt, "clip"):
            nn.utils.clip_grad_norm(self.netG.parameters(),
                                    self.opt.warm_opt.clip)

    def d_loss(self, ):
        """Compute and backpropagate both discriminators' losses.

        Scores real/fake in image space (netD) and VGG-feature space
        (netD_vgg); adds gradient penalties when using WGAN-GP. Generator
        outputs are detached so only discriminators receive gradients.
        """
        d_fake = self.netD(self.output_hr.detach())
        d_real = self.netD(self.input_hr)

        vgg_r = self.opt.gen.lambda_vgg_input
        x_f_fake = self.vgg(vgg_r * self.output_hr.detach())
        x_f_real = self.vgg(vgg_r * self.input_hr)

        vgg_d_fake = self.netD_vgg(x_f_fake)
        vgg_d_real = self.netD_vgg(x_f_real)

        self.loss_D_f = self.gan_loss(d_fake, False)
        self.loss_D_r = self.gan_loss(d_real, True)

        self.loss_Df_f = self.gan_loss(vgg_d_fake, False)
        self.loss_Df_r = self.gan_loss(vgg_d_real, True)

        #self.loss_d_f_fake = 0
        #self.loss_d_f_real = 0
        if self.opt.gan_loss_fn == "wgangp":
            # train with gradient penalty
            gradient_penalty_vgg,_ = cal_gradient_penalty(netD=self.netD_vgg, real_data=x_f_real.data,\
                fake_data=x_f_fake.data,device=self.device)
            gradient_penalty_vgg.backward()

            gradient_penalty,_ = cal_gradient_penalty(netD=self.netD, real_data=self.input_hr.data, \
                fake_data = self.output_hr.data,  device=self.device)
            gradient_penalty.backward()



        loss_d =self.loss_D_f+ self.loss_D_r +self.loss_Df_f+\
            self.loss_Df_r
        #print ("loss_d",loss_d.item() )

        loss_d.backward()

    def get_current_errors(self):
        """Return an ordered dict of current scalar losses for logging.

        Missing losses (not yet computed this run) default to 0.
        """
        return OrderedDict([
            ('G_p', self.loss_G_p.item() if hasattr(self, "loss_G_p") else 0),
            ('G_fg',
             self.loss_G_fg.item() if hasattr(self, "loss_G_fg") else 0),
            ('G_g', self.loss_G_g.item() if hasattr(self, "loss_G_g") else 0),
            ('D_f_real',
             self.loss_Df_r.item() if hasattr(self, "loss_Df_r") else 0),
            ('D_f_fake',
             self.loss_Df_f.item() if hasattr(self, "loss_Df_f") else 0),
            ('D_real',
             self.loss_D_r.item() if hasattr(self, "loss_D_r") else 0),
            ('D_fake',
             self.loss_D_f.item() if hasattr(self, "loss_D_f") else 0),
            ('warm_p',
             self.loss_w_g.item() if hasattr(self, "loss_w_g") else 0),
        ])

    def get_current_visuals(self):
        """Return displayable images: bicubic input, generator fake, GT."""
        input = util.tensor2im(self.input_cubic_hr)
        target = util.tensor2im(self.input_hr)
        fake = util.tensor2im(self.output_hr.detach())
        return OrderedDict([('input', input), ('fake', fake),
                            ('target', target)])

    def update_learning_rate(self, is_warm=True):
        """Step all LR schedulers and return the new (lr_g, lr_d) pair.

        NOTE(review): `is_warm` is accepted but ignored (`if True:`) —
        confirm whether warm-specific scheduling was intended.
        """
        if True:
            for scheduler in self.schedulers:
                scheduler.step()

            lr_g = self.optim_G.param_groups[0]['lr']
            lr_d = self.optim_D.param_groups[0]['lr']
            return (lr_g, lr_d)

    def save_networks(self, epoch):
        """Save all the networks to the disk.
        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                # NOTE(review): `self.gpu_ids` is never assigned in this
                # class — the parallel branch would raise AttributeError.
                if "parallel" in str(type(net)) and torch.cuda.is_available():
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

                # Saving moved the net to CPU; restore it to the device.
                net.to(self.device)

    def load_networks(self, epoch=None):
        """Load all the networks from the disk.
        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)

        Returns the epoch number that was restored (0 if nothing found).
        """
        # Collect numeric epoch prefixes from '<epoch>_net_<name>.pth' files.
        pth_list = [os.path.basename(x) for x in os.listdir(self.save_dir)]
        pth_list = [
            x.split("_")[0] for x in pth_list
            if "_net" in x and ".pth" in x and "best" not in x
        ]
        # NOTE(review): the first sort is lexicographic and [:-1] drops its
        # last entry before the numeric sort — verify this is intended.
        pth_list = sorted(pth_list)[:-1]
        pth_list = list(map(int, pth_list))
        pth_list = sorted(pth_list)
        current_epoch = 0
        try:
            current_epoch = int(pth_list[-1])
        except:
            pass

        if current_epoch <= 0:
            return current_epoch

        epoch = current_epoch
        #for name in self.model_names:
        for name in ["G", "D", "D_vgg"]:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                if not os.path.isfile(load_path):
                    print("***", "fail find%s" % (load_path))
                    continue
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path,
                                        map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints prior to 0.4
                #for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                #    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

        return current_epoch
예제 #28
0
File: train.py  Project: wlwkgus/GibbsNet
import time
from data.data_loader import get_data_loader
from models.models import create_model
from option_parser import TrainingOptionParser
from utils.visualizer import Visualizer

# Training entry script: parse options, build the data loader and model,
# then run the optimization loop, tracking step counts for logging.
parser = TrainingOptionParser()
opt = parser.parse_args()

data_loader = get_data_loader(opt)

print("[INFO] batch size : {}".format(opt.batch_size))
print("[INFO] training batches : {}".format(len(data_loader)))

model = create_model(opt)
visualizer = Visualizer(opt)
total_steps = 0   # cumulative samples processed across all epochs
epoch_count = 0

for epoch in range(opt.epoch):
    epoch_start_time = time.time()
    iter_count = 0  # samples processed within the current epoch

    for i, data in enumerate(data_loader):
        batch_start_time = time.time()
        # Step counters advance by batch size, not by iteration count.
        total_steps += opt.batch_size
        iter_count += opt.batch_size
        # data : list — only the first element is fed to the model
        # (presumably inputs without labels; verify against the loader).
        model.set_input(data[0])
        model.optimize_parameters()
        batch_end_time = time.time()
예제 #29
0
    # NOTE(review): this constructor is a byte-for-byte duplicate of
    # Treainer.__init__ earlier in this file — consider deduplicating.
    # NOTE(review): `dis_list=[]` is a mutable default argument.
    def __init__(self,
                 opt=None,
                 train_dt=None,
                 train_dt_warm=None,
                 dis_list=[],
                 val_dt_warm=None):
        """Build the GAN super-resolution trainer.

        Sets up: device, visualizer, optional distributed training,
        generator/discriminators/VGG feature extractor, optimizers and LR
        schedulers, train/val data loaders, and loss criteria.

        Args:
            opt: project options namespace (architecture, lr, losses, ...).
            train_dt / train_dt_warm: training datasets (concatenated).
            dis_list: [rank, num_gpus, group_name, dist_config] — TODO
                confirm layout against the caller.
            val_dt_warm: warm-up validation dataset.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.opt = opt

        self.visualizer = Visualizer(opt)

        num_gpus = torch.cuda.device_count()
        #dis_list[1]
        print(dis_list)
        #torch.cuda.device_count()
        self.rank = dis_list[0]
        print(self.rank)

        #=====START: ADDED FOR DISTRIBUTED======
        if num_gpus > 1:
            #init_distributed(rank, num_gpus, group_name, **dist_config)
            dist_config = dis_list[3]
            init_distributed(dis_list[0], dis_list[1], dis_list[2],
                             **dist_config)
        #=====END:   ADDED FOR DISTRIBUTED======

        # Generator architecture selected by option string.
        if opt.ge_net == "srfeat":
            self.netG = model.G()
        elif opt.ge_net == "carn":
            self.netG = model.G1()
        elif opt.ge_net == "carnm":
            self.netG = model.G2()
        else:
            raise Exception("unknow ")

        # Feature-space discriminator over VGG feature maps.
        self.netD_vgg = model.D(input_c=512, input_width=18)

        # Image-space discriminator.
        self.netD = model.D()

        # Frozen VGG used only for perceptual features.
        if opt.vgg_type == "style":
            self.vgg = load_vgg16(opt.vgg_model_path + '/models')
        elif opt.vgg_type == "classify":
            self.vgg = model.vgg19_withoutbn_customefinetune()

        self.vgg.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False

#         for p in self.vgg.parameters():
#             p.requires_grad = False

        init_weights(self.netD, init_type=opt.init)
        init_weights(self.netD_vgg, init_type=opt.init)
        init_weights(self.netG, init_type=opt.init)

        self.vgg = self.vgg.to(self.device)
        self.netD = self.netD.to(self.device)
        self.netD_vgg = self.netD_vgg.to(self.device)
        self.netG = self.netG.to(self.device)

        #=====START: ADDED FOR DISTRIBUTED======
        if num_gpus > 1:
            #self.vgg = apply_gradient_allreduce(self.vgg)
            self.netD_vgg = apply_gradient_allreduce(self.netD_vgg)
            self.netD = apply_gradient_allreduce(self.netD)
            self.netG = apply_gradient_allreduce(self.netG)

        #=====END:   ADDED FOR DISTRIBUTED======

        print(opt)

        # Generator optimizer uses warm-up hyperparameters (warm_opt.*).
        self.optim_G= torch. optim.Adam(filter(lambda p: p.requires_grad, self.netG.parameters()),\
         lr=opt.warm_opt.lr, betas=opt.warm_opt.betas, weight_decay=0.0)

        #        self.optim_G= torch.optim.Adam(filter(lambda p: p.requires_grad, self.netG.parameters()),\
        #         lr=opt.gen.lr, betas=opt.gen.betas, weight_decay=0.0)

        # Both discriminators share one optimizer (chained parameters).
        if opt.dis.optim == "sgd":
            self.optim_D= torch.optim.SGD( filter(lambda p: p.requires_grad, \
                itertools.chain(self.netD_vgg.parameters(),self.netD.parameters() ) ),\
                lr=opt.dis.lr,
             )
        elif opt.dis.optim == "adam":
            self.optim_D= torch.optim.Adam( filter(lambda p: p.requires_grad, \
                itertools.chain(self.netD_vgg.parameters(),self.netD.parameters() ) ),\
                lr=opt.dis.lr,betas=opt.dis.betas, weight_decay=0.0
             )
        else:
            raise Exception("unknown")

        print("create schedule ")

        lr_sc_G = get_scheduler(self.optim_G, opt.gen)
        lr_sc_D = get_scheduler(self.optim_D, opt.dis)

        self.schedulers = []

        self.schedulers.append(lr_sc_G)
        self.schedulers.append(lr_sc_D)

        # =====START: ADDED FOR DISTRIBUTED======
        # Main and warm-up training sets merged into one loader.
        train_dt = torch.utils.data.ConcatDataset([train_dt, train_dt_warm])

        train_sampler = DistributedSampler(train_dt) if num_gpus > 1 else None
        val_sampler_warm = DistributedSampler(
            val_dt_warm) if num_gpus > 1 else None
        # =====END:   ADDED FOR DISTRIBUTED======

        kw = {
            "pin_memory": True,
            "num_workers": 8
        } if torch.cuda.is_available() else {}
        dl_c =t_data.DataLoader(train_dt ,batch_size=opt.batch_size,\
             sampler=train_sampler , drop_last=True, **kw )

        # Validation loader prefers opt.batch_size_warm when present.
        dl_val_warm = t_data.DataLoader(
            val_dt_warm,
            batch_size=opt.batch_size
            if not hasattr(opt, "batch_size_warm") else opt.batch_size_warm,
            sampler=val_sampler_warm,
            drop_last=True,
            **kw)

        self.dt_train = dl_c
        self.dt_val_warm = dl_val_warm

        # Pixel-wise reconstruction criterion selected by option.
        if opt.warm_opt.loss_fn == "mse":
            self.critic_pixel = torch.nn.MSELoss()
        elif opt.warm_opt.loss_fn == "l1":
            self.critic_pixel = torch.nn.L1Loss()
        elif opt.warm_opt.loss_fn == "smooth_l1":
            self.critic_pixel = torch.nn.SmoothL1Loss()
        else:
            raise Exception("unknown")

        self.critic_pixel = self.critic_pixel.to(self.device)

        self.gan_loss = GANLoss(gan_mode=opt.gan_loss_fn).to(self.device)
        print("init ....")

        # Checkpoints are saved next to the visualizer's log file.
        self.save_dir = os.path.dirname(self.visualizer.log_name)
예제 #30
0
def main():
    opt = fake_opt.JointTrain()
    device = torch.device("cuda:{}".format(opt.gpu_ids[0]) if len(opt.gpu_ids)
                          > 0 and torch.cuda.is_available() else "cpu")

    visualizer = Visualizer(opt)
    logging = visualizer.get_logger()
    acc_report = visualizer.add_plot_report(['train/acc', 'val/acc'],
                                            'acc.png')
    loss_report = visualizer.add_plot_report(
        ['train/loss', 'val/loss', 'train/enhance_loss', 'val/enhance_loss'],
        'loss.png')

    # data
    logging.info("Building dataset.")
    train_dataset = MixSequentialDataset(
        opt,
        os.path.join(opt.dataroot, 'train_new'),
        os.path.join(opt.dict_dir, 'train/vocab'),
    )
    val_dataset = MixSequentialDataset(
        opt,
        os.path.join(opt.dataroot, 'dev_new'),
        os.path.join(opt.dict_dir, 'train/vocab'),
    )
    train_sampler = BucketingSampler(train_dataset, batch_size=opt.batch_size)
    train_loader = MixSequentialDataLoader(train_dataset,
                                           num_workers=opt.num_workers,
                                           batch_sampler=train_sampler)
    val_loader = MixSequentialDataLoader(val_dataset,
                                         batch_size=int(opt.batch_size / 2),
                                         num_workers=opt.num_workers,
                                         shuffle=False)
    opt.idim = train_dataset.get_feat_size()
    opt.odim = train_dataset.get_num_classes()
    opt.char_list = train_dataset.get_char_list()
    opt.train_dataset_len = len(train_dataset)
    logging.info('#input dims : ' + str(opt.idim))
    logging.info('#output dims: ' + str(opt.odim))
    logging.info("Dataset ready!")

    # Setup an model
    lr = opt.lr
    eps = opt.eps
    iters = opt.iters
    best_acc = opt.best_acc
    best_loss = opt.best_loss
    start_epoch = opt.start_epoch

    enhance_model_path = None
    if opt.enhance_resume:
        #enhance_model_path = os.path.join(opt.works_dir, opt.enhance_resume)
        enhance_model_path = "/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/enhance_fbank_train_table_2/model.loss.best"
        if os.path.isfile(enhance_model_path):
            enhance_model = EnhanceModel.load_model(enhance_model_path,
                                                    'enhance_state_dict', opt)
        else:
            print("no checkpoint found at {}".format(enhance_model_path))

    asr_model_path = None
    if opt.asr_resume:
        #asr_model_path = os.path.join(opt.works_dir, opt.asr_resume)
        asr_model_path = "/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/asr_mix_train_table3_1/model.acc.best"
        if os.path.isfile(asr_model_path):
            #asr_model = ShareE2E.load_model(asr_model_path, 'asr_state_dict', opt)
            asr_model = E2E.load_model(asr_model_path, 'asr_state_dict', opt)
        else:
            print("no checkpoint found at {}".format(asr_model_path))

    joint_model_path = None
    if opt.joint_resume:
        joint_model_path = os.path.join(opt.works_dir, opt.joint_resume)
        if os.path.isfile(joint_model_path):
            package = torch.load(joint_model_path,
                                 map_location=lambda storage, loc: storage)
            lr = package.get('lr', opt.lr)
            eps = package.get('eps', opt.eps)
            best_acc = package.get('best_acc', 0)
            best_loss = package.get('best_loss', float('inf'))
            start_epoch = int(package.get('epoch', 0))
            iters = int(package.get('iters', 0)) - 1
            print('joint_model_path {} and iters {}'.format(
                joint_model_path, iters))
            ##loss_report = package.get('loss_report', loss_report)
            ##visualizer.set_plot_report(loss_report, 'loss.png')
        else:
            print("no checkpoint found at {}".format(joint_model_path))
    if joint_model_path is not None or enhance_model_path is None:
        enhance_model_path_with_gan = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/enhance_gan_train_both_enhance_cmvn/model.loss.best'
        enhance_model = EnhanceModel.load_model(enhance_model_path_with_gan,
                                                'enhance_state_dict', opt)
    if joint_model_path is not None or asr_model_path is None:
        #asr_model = ShareE2E.load_model(joint_model_path, 'asr_state_dict', opt)
        asr_model_path = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/asr_train/model.acc.best'
        asr_model = E2E.load_model(asr_model_path, 'asr_state_dict', opt)
    feat_model = FbankModel.load_model(joint_model_path, 'fbank_state_dict',
                                       opt)
    if opt.isGAN:
        gan_model = GANModel.load_model(enhance_model_path_with_gan,
                                        'gan_state_dict', opt)
    ##set_requires_grad([enhance_model], False)

    # Setup an optimizer
    enhance_parameters = filter(lambda p: p.requires_grad,
                                enhance_model.parameters())
    asr_parameters = filter(lambda p: p.requires_grad, asr_model.parameters())
    if opt.isGAN:
        gan_parameters = filter(lambda p: p.requires_grad,
                                gan_model.parameters())
    if opt.opt_type == 'adadelta':
        enhance_optimizer = torch.optim.Adadelta(enhance_parameters,
                                                 rho=0.95,
                                                 eps=eps)
        asr_optimizer = torch.optim.Adadelta(asr_parameters, rho=0.95, eps=eps)
        if opt.isGAN:
            gan_optimizer = torch.optim.Adadelta(gan_parameters,
                                                 rho=0.95,
                                                 eps=eps)
    elif opt.opt_type == 'adam':
        enhance_optimizer = torch.optim.Adam(enhance_parameters,
                                             lr=lr,
                                             betas=(opt.beta1, 0.999))
        asr_optimizer = torch.optim.Adam(asr_parameters,
                                         lr=lr,
                                         betas=(opt.beta1, 0.999))
        if opt.isGAN:
            gan_optimizer = torch.optim.Adam(gan_parameters,
                                             lr=lr,
                                             betas=(opt.beta1, 0.999))
    if opt.isGAN:
        criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan).to(device)

    # Training
    #enhance_cmvn_path = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/joint_train/enhance_cmvn.npy'

    enhance_cmvn_path = None
    if enhance_cmvn_path:
        enhance_cmvn = np.load(enhance_cmvn_path)
        enhance_cmvn = torch.FloatTensor(enhance_cmvn)
    else:
        enhance_cmvn = compute_cmvn_epoch(opt, train_loader, enhance_model,
                                          feat_model)
    sample_rampup = utils.ScheSampleRampup(opt.sche_samp_start_iter,
                                           opt.sche_samp_final_iter,
                                           opt.sche_samp_final_rate)
    sche_samp_rate = sample_rampup.update(iters)

    fbank_cmvn_file = os.path.join(opt.exp_path, 'fbank_cmvn.npy')
    fbank_cmvn = np.load(fbank_cmvn_file)
    fbank_cmvn = torch.FloatTensor(fbank_cmvn)

    enhance_model.train()
    feat_model.train()
    asr_model.train()
    # Main training loop: joint enhancement + ASR (+ optional GAN) updates.
    for epoch in range(start_epoch, opt.epochs):
        if epoch > opt.shuffle_epoch:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle(epoch)
        # NOTE(review): loop variable `i` is unused here and is shadowed by
        # the inner validation loop below.
        for i, (data) in enumerate(train_loader, start=0):
            utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs, mix_log_inputs, cos_angles, targets, input_sizes, target_sizes = data
            # Enhance the noisy mixture, then extract features from the
            # enhanced, clean and mixed signals.
            enhance_out = enhance_model(mix_inputs, mix_log_inputs,
                                        input_sizes)
            enhance_feat = feat_model(enhance_out)
            clean_feat = feat_model(clean_inputs)
            mix_feat = feat_model(mix_inputs)
            # Enhancement loss: distance between enhanced and (detached)
            # clean features, selected by opt.enhance_loss_type.
            if opt.enhance_loss_type == 'L2':
                enhance_loss = F.mse_loss(enhance_feat, clean_feat.detach())
            elif opt.enhance_loss_type == 'L1':
                enhance_loss = F.l1_loss(enhance_feat, clean_feat.detach())
            elif opt.enhance_loss_type == 'smooth_L1':
                enhance_loss = F.smooth_l1_loss(enhance_feat,
                                                clean_feat.detach())
            enhance_loss = opt.enhance_loss_lambda * enhance_loss
            # CMVN-normalized features feed the ASR model.
            enhance_feature = feat_model(enhance_out, enhance_cmvn)
            clean_feature = feat_model(clean_inputs, fbank_cmvn)
            loss_ctc, loss_att, acc = asr_model(enhance_feature, targets,
                                                input_sizes, target_sizes,
                                                sche_samp_rate)

            #loss_ctc, loss_att, acc, clean_context, mix_context = asr_model(clean_feat, enhance_feat, targets, input_sizes, target_sizes, sche_samp_rate, enhance_cmvn)
            #coral_loss = opt.coral_loss_lambda * CORAL(clean_context, mix_context)
            # CORAL domain-adaptation loss is disabled (kept at 0).
            coral_loss = 0
            # Hybrid CTC/attention loss interpolated by mtlalpha.
            asr_loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att
            loss = asr_loss + enhance_loss + coral_loss
            #loss = asr_loss

            if opt.isGAN:
                # Generator step: freeze D, add adversarial loss on the
                # enhanced output (D should judge it "real").
                set_requires_grad([gan_model], False)
                if opt.netD_type == 'pixel':
                    fake_AB = torch.cat((mix_feat, enhance_feat), 2)
                else:
                    fake_AB = enhance_feature
                gan_loss = opt.gan_loss_lambda * criterionGAN(
                    gan_model(fake_AB), True)
                loss += gan_loss
            # NOTE(review): the enhancement model's grads are disabled right
            # before enhance_optimizer.step() — so that step is presumably a
            # no-op for this batch; confirm this freezing is intentional.
            set_requires_grad([enhance_model], False)
            enhance_optimizer.zero_grad()
            asr_optimizer.zero_grad()  # Clear the parameter gradients
            loss.backward()
            # compute the gradient norm to check if it is normal or not
            grad_norm = torch.nn.utils.clip_grad_norm_(asr_model.parameters(),
                                                       opt.grad_clip)
            if math.isnan(grad_norm):
                logging.warning('grad norm is nan. Do not update model.')
            else:
                enhance_optimizer.step()
                asr_optimizer.step()

            if opt.isGAN:
                # Discriminator step: unfreeze D and train it to separate
                # real (clean) from fake (enhanced) features.
                set_requires_grad([gan_model], True)
                gan_optimizer.zero_grad()
                if opt.netD_type == 'pixel':
                    fake_AB = torch.cat((mix_feat, enhance_feat), 2)
                    real_AB = torch.cat((mix_feat, clean_feat), 2)
                else:
                    fake_AB = enhance_feature
                    real_AB = clean_feature
                loss_D_real = criterionGAN(gan_model(real_AB.detach()), True)
                loss_D_fake = criterionGAN(gan_model(fake_AB.detach()), False)
                loss_D = (loss_D_real + loss_D_fake) * 0.5
                loss_D.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    gan_model.parameters(), opt.grad_clip)
                if math.isnan(grad_norm):
                    logging.warning('grad norm is nan. Do not update model.')
                else:
                    gan_optimizer.step()

            iters += 1
            # Accumulate per-iteration training metrics for the visualizer.
            errors = {
                'train/loss': loss.item(),
                'train/loss_ctc': loss_ctc.item(),
                'train/acc': acc,
                'train/loss_att': loss_att.item(),
                'train/enhance_loss': enhance_loss.item()
            }
            if opt.isGAN:
                errors['train/loss_D'] = loss_D.item()
                # NOTE(review): gan_loss was already scaled by
                # opt.gan_loss_lambda above, so this reports lambda^2 * loss —
                # looks like a double scaling; verify intent.
                errors['train/gan_loss'] = opt.gan_loss_lambda * gan_loss.item(
                )

            visualizer.set_current_errors(errors)
            # Periodically log and snapshot a 'latest' checkpoint.
            if iters % opt.print_freq == 0:
                visualizer.print_current_errors(epoch, iters)
                state = {
                    'asr_state_dict': asr_model.state_dict(),
                    'fbank_state_dict': feat_model.state_dict(),
                    'enhance_state_dict': enhance_model.state_dict(),
                    'opt': opt,
                    'epoch': epoch,
                    'iters': iters,
                    'eps': opt.eps,
                    'lr': opt.lr,
                    'best_loss': best_loss,
                    'best_acc': best_acc,
                    'acc_report': acc_report,
                    'loss_report': loss_report
                }
                if opt.isGAN:
                    state['gan_state_dict'] = gan_model.state_dict()
                filename = 'latest'
                utils.save_checkpoint(state, opt.exp_path, filename=filename)
            # Periodic validation pass over the whole val set.
            if iters % opt.validate_freq == 0:
                # Refresh the scheduled-sampling rate for subsequent training.
                sche_samp_rate = sample_rampup.update(iters)
                print("iters {} sche_samp_rate {}".format(
                    iters, sche_samp_rate))
                # Switch to eval mode and disable autograd for validation.
                enhance_model.eval()
                feat_model.eval()
                asr_model.eval()
                torch.set_grad_enabled(False)
                num_saved_attention = 0
                for i, (data) in tqdm(enumerate(val_loader, start=0)):
                    utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs, mix_log_inputs, cos_angles, targets, input_sizes, target_sizes = data
                    enhance_out = enhance_model(mix_inputs, mix_log_inputs,
                                                input_sizes)
                    enhance_feat = feat_model(enhance_out)
                    clean_feat = feat_model(clean_inputs)
                    mix_feat = feat_model(mix_inputs)
                    clean_feat_val = feat_model(clean_inputs, fbank_cmvn)
                    enhance_feat_val = feat_model(enhance_out, enhance_cmvn)
                    # Same enhancement loss as in training.
                    if opt.enhance_loss_type == 'L2':
                        enhance_loss = F.mse_loss(enhance_feat,
                                                  clean_feat.detach())
                    elif opt.enhance_loss_type == 'L1':
                        enhance_loss = F.l1_loss(enhance_feat,
                                                 clean_feat.detach())
                    elif opt.enhance_loss_type == 'smooth_L1':
                        enhance_loss = F.smooth_l1_loss(
                            enhance_feat, clean_feat.detach())
                    if opt.isGAN:
                        set_requires_grad([gan_model], False)
                        if opt.netD_type == 'pixel':
                            fake_AB = torch.cat((mix_feat, enhance_feat), 2)
                        else:
                            fake_AB = enhance_feat_val
                        gan_loss = criterionGAN(gan_model(fake_AB), True)
                        # NOTE(review): the gan term is scaled by
                        # gan_loss_lambda here and the sum is scaled again by
                        # enhance_loss_lambda below — verify the double
                        # scaling is intended.
                        enhance_loss += opt.gan_loss_lambda * gan_loss

                    #loss_ctc, loss_att, acc, clean_context, mix_context = asr_model(clean_feat, enhance_feat, targets, input_sizes, target_sizes, 0.0, enhance_cmvn)
                    # NOTE(review): validation uses the ramped sche_samp_rate;
                    # the commented-out variant used 0.0 (pure teacher
                    # forcing) — confirm which is intended for eval.
                    loss_ctc, loss_att, acc = asr_model(
                        enhance_feat_val, targets, input_sizes, target_sizes,
                        sche_samp_rate)

                    asr_loss = opt.mtlalpha * loss_ctc + (
                        1 - opt.mtlalpha) * loss_att
                    enhance_loss = opt.enhance_loss_lambda * enhance_loss
                    loss = asr_loss + enhance_loss
                    errors = {
                        'val/loss': loss.item(),
                        'val/loss_ctc': loss_ctc.item(),
                        'val/acc': acc,
                        'val/loss_att': loss_att.item(),
                        'val/enhance_loss': enhance_loss.item()
                    }
                    if opt.isGAN:
                        errors[
                            'val/gan_loss'] = opt.gan_loss_lambda * gan_loss.item(
                            )
                    visualizer.set_current_errors(errors)

                    # Dump up to opt.num_save_attention attention plots when
                    # the attention branch is active (mtlalpha != 1.0).
                    if opt.num_save_attention > 0 and opt.mtlalpha != 1.0:
                        if num_saved_attention < opt.num_save_attention:
                            att_ws = asr_model.calculate_all_attentions(
                                enhance_feat_val, targets, input_sizes,
                                target_sizes)
                            for x in range(len(utt_ids)):
                                att_w = att_ws[x]
                                utt_id = utt_ids[x]
                                file_name = "{}_ep{}_it{}.png".format(
                                    utt_id, epoch, iters)
                                dec_len = int(target_sizes[x])
                                enc_len = int(input_sizes[x])
                                visualizer.plot_attention(
                                    att_w, dec_len, enc_len, file_name)
                                num_saved_attention += 1
                                if num_saved_attention >= opt.num_save_attention:
                                    break
                # Restore training mode and autograd.
                enhance_model.train()
                feat_model.train()
                asr_model.train()
                torch.set_grad_enabled(True)

                # Aggregate epoch-level metrics and plots.
                visualizer.print_epoch_errors(epoch, iters)
                acc_report = visualizer.plot_epoch_errors(
                    epoch, iters, 'acc.png')
                loss_report = visualizer.plot_epoch_errors(
                    epoch, iters, 'loss.png')
                val_loss = visualizer.get_current_errors('val/loss')
                val_acc = visualizer.get_current_errors('val/acc')
                # Select the best model according to the validation criterion
                # and decay the Adadelta eps when the model did not improve.
                # `filename` stays None (-> no "best" snapshot) on regression.
                filename = None
                # Use '!=' for string comparison: the original `is not 'ctc'`
                # compared object identity against a literal, which is
                # implementation-dependent (SyntaxWarning on CPython >= 3.8).
                if opt.criterion == 'acc' and opt.mtl_mode != 'ctc':
                    if val_acc < best_acc:
                        # Accuracy regressed -> decay eps instead of saving.
                        # (Message fixed: condition is '<', not '>'.)
                        logging.info('val_acc {} < best_acc {}'.format(
                            val_acc, best_acc))
                        opt.eps = utils.adadelta_eps_decay(
                            asr_optimizer, opt.eps_decay)
                    else:
                        filename = 'model.acc.best'
                    best_acc = max(best_acc, val_acc)
                    logging.info('best_acc {}'.format(best_acc))
                elif opt.criterion == 'loss':
                    if val_loss > best_loss:
                        # Loss regressed -> decay eps instead of saving.
                        logging.info('val_loss {} > best_loss {}'.format(
                            val_loss, best_loss))
                        opt.eps = utils.adadelta_eps_decay(
                            asr_optimizer, opt.eps_decay)
                    else:
                        filename = 'model.loss.best'
                    best_loss = min(val_loss, best_loss)
                    logging.info('best_loss {}'.format(best_loss))
                # Snapshot the full joint-model state; `filename` is the
                # best-model name chosen above (or None for no best snapshot).
                state = {
                    'asr_state_dict': asr_model.state_dict(),
                    'fbank_state_dict': feat_model.state_dict(),
                    'enhance_state_dict': enhance_model.state_dict(),
                    'opt': opt,
                    'epoch': epoch,
                    'iters': iters,
                    'eps': opt.eps,
                    'lr': opt.lr,
                    'best_loss': best_loss,
                    'best_acc': best_acc,
                    'acc_report': acc_report,
                    'loss_report': loss_report
                }
                if opt.isGAN:
                    state['gan_state_dict'] = gan_model.state_dict()
                utils.save_checkpoint(state, opt.exp_path, filename=filename)
                # Reset accumulated metrics and re-estimate the enhanced-
                # feature CMVN stats with the freshly updated enhancer.
                visualizer.reset()
                enhance_cmvn = compute_cmvn_epoch(opt, train_loader,
                                                  enhance_model, feat_model)
예제 #31
0
import time
from data.data_loader import get_data_loader
from models.models import create_model
from option_parser import TrainingOptionParser
from utils.visualizer import Visualizer

# Parse training options and build the data pipeline, model and visualizer.
parser = TrainingOptionParser()
opt = parser.parse_args()

data_loader = get_data_loader(opt)

print("[INFO] batch size : {}".format(opt.batch_size))
print("[INFO] training batches : {}".format(len(data_loader)))

model = create_model(opt)
visualizer = Visualizer(opt)
# Counters: total samples seen across all epochs / completed epochs.
total_steps = 0
epoch_count = 0

# Training loop (continues beyond this excerpt).
for epoch in range(opt.epoch):
    epoch_start_time = time.time()
    iter_count = 0

    for i, data in enumerate(data_loader):
        batch_start_time = time.time()
        # Step counters advance by batch size (samples, not batches).
        total_steps += opt.batch_size
        iter_count += opt.batch_size
        # data : list — the batch tensors are wrapped in a list; only the
        # first element is fed to the model.
        model.set_input(data[0])
        model.optimize_parameters()
        batch_end_time = time.time()