コード例 #1
0
def train():
    """Train the F model with periodic visualization and per-epoch validation.

    Plots losses and sample images every 20 batches, evaluates on the
    validation set each epoch, and checkpoints the model whenever a new
    best validation SNR is reached after a 100-epoch warm-up.
    Reads all configuration from the module-level ``opt`` object.
    """
    opt._parse()
    vis_tool = Visualizer(env=opt.env)

    print('load data')
    train_dataset = Datatrain(opt.rootpath, mode="train/")
    val_dataset = Dataval(opt.rootpath, mode="val/")

    trainer = Ftrainer(opt, image_size=opt.image_size)
    if opt.load_G:
        trainer.load_G(opt.load_G)
    print('model G construct completed')

    if opt.load_F:
        trainer.load_F(opt.load_F)
        print('model F construct completed')

    # DataLoaders are loop-invariant: build them once instead of every
    # epoch. shuffle=True still reshuffles each epoch when iterated.
    train_dataloader = data_.DataLoader(train_dataset,
                                        batch_size=opt.train_batch_size,
                                        shuffle=True,
                                        num_workers=opt.num_workers)
    val_dataloader = data_.DataLoader(val_dataset,
                                      batch_size=opt.test_batch_size,
                                      num_workers=opt.num_workers,
                                      shuffle=False)

    best_map = 0.0
    for epoch in range(opt.epoch):
        trainer.train()
        for ii, (img, oriimg, mask) in tqdm(enumerate(train_dataloader),
                                            total=len(train_dataloader)):
            loss, loss1, loss2 = trainer.train_onebatch(img, oriimg, mask)
            if ii % 20 == 0:
                # Visualize in eval mode; training mode is restored below.
                trainer.eval()
                vis_tool.plot("totalloss", loss.detach().cpu().numpy())
                vis_tool.plot("loss_r", loss1.detach().cpu().numpy())
                vis_tool.plot("loss_t", loss2.detach().cpu().numpy())
                # Forward a 2-sample slice purely for visualization output.
                snr, output, edg, edg2 = trainer(img[0:2, :, :, :],
                                                 oriimg[0:2, :, :, :],
                                                 mask[0:2, :, :, :],
                                                 vis=True)
                vis_tool.plot("snr_train", snr)
                # 'input_img' avoids shadowing the builtin `input`.
                # Scale [0, 1] floats to displayable uint8 pixels.
                input_img = (img[0][0].numpy() * 255).astype(np.uint8)
                vis_tool.img("input", input_img)
                label = (oriimg[0][0].numpy() * 255).astype(np.uint8)
                vis_tool.img("label", label)
                vis_pic(output, round(snr, 2), vis_tool)
                vis_tool.img("predict_segm", edg[0])
                vis_tool.img("ori_segm", edg2[0])
                trainer.train()

        # Persist intermediate test results every 10 epochs.
        eval_result = test_model(val_dataloader,
                                 trainer,
                                 epoch,
                                 ifsave=(epoch + 1) % 10 == 0,
                                 test_num=opt.test_num)
        print('eval_loss: ', eval_result)

        vis_tool.plot("SNR_val", eval_result["SNR"])
        # Only track "best" checkpoints after a 100-epoch warm-up.
        if epoch > 100 and eval_result["SNR"] > best_map:
            best_map = eval_result["SNR"]
            best_path = trainer.save_F(best_map=best_map)
            print("save to %s !" % best_path)
コード例 #2
0
def train(**kwargs):
    """Train a fast-neural-style TransformerNet against a fixed style image.

    Args:
        **kwargs: overrides for fields on the module-level ``opt`` config,
            applied via ``opt._parse``.

    Saves the visdom environment and a model checkpoint after every epoch
    under ``checkpoints/``.
    """
    opt._parse(kwargs)

    device = t.device('cuda') if opt.use_gpu else t.device('cpu')
    vis = Visualizer(opt.env)

    # Data loading: square resize/crop, then scale to [0, 255] since the
    # transformer network operates on un-normalized pixel values.
    transforms = tv.transforms.Compose([
        tv.transforms.Resize(opt.image_size),
        tv.transforms.CenterCrop(opt.image_size),
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: x * 255)
    ])
    dataset = tv.datasets.ImageFolder(opt.data_root, transforms)
    dataloader = data.DataLoader(dataset, opt.batch_size)

    # Style transformer network, optionally resumed from a checkpoint
    # (map_location keeps tensors on their serialized device's storage).
    transformer = TransformerNet()
    if opt.model_path:
        transformer.load_state_dict(
            t.load(opt.model_path, map_location=lambda _s, _: _s))
    transformer.to(device)

    # Frozen VGG16: used only as a fixed feature extractor for the
    # perceptual (content + style) losses.
    vgg = Vgg16().eval()
    vgg.to(device)
    for param in vgg.parameters():
        param.requires_grad = False

    optimizer = t.optim.Adam(transformer.parameters(), opt.lr)

    # Style image; the 0.225/0.45 constants denormalize for display only.
    style = utils.get_style_data(opt.style_path)
    vis.img('style', (style.data[0] * 0.225 + 0.45).clamp(min=0, max=1))
    style = style.to(device)

    # Precompute the style targets (Gram matrices) once — they never change.
    with t.no_grad():
        features_style = vgg(style)
        gram_style = [utils.gram_matrix(y) for y in features_style]

    # Running-average meters smooth the plotted loss curves.
    style_meter = tnt.meter.AverageValueMeter()
    content_meter = tnt.meter.AverageValueMeter()

    # Ensure the checkpoint directory exists before the first t.save,
    # which would otherwise raise FileNotFoundError.
    os.makedirs('checkpoints', exist_ok=True)

    for epoch in range(opt.epoches):
        content_meter.reset()
        style_meter.reset()

        for ii, (x, _) in tqdm.tqdm(enumerate(dataloader)):
            optimizer.zero_grad()
            x = x.to(device)
            y = transformer(x)
            # VGG-normalize both the stylized output and the input before
            # extracting features.
            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)
            features_y = vgg(y)
            features_x = vgg(x)

            # Content loss: match relu2_2 activations of output vs. input.
            content_loss = opt.content_weight * F.mse_loss(
                features_y.relu2_2, features_x.relu2_2)

            # Style loss: match Gram matrices at every feature level.
            style_loss = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gram_y = utils.gram_matrix(ft_y)
                style_loss += F.mse_loss(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            # Accumulate smoothed losses for visualization.
            content_meter.add(content_loss.item())
            style_meter.add(style_loss.item())

            if (ii + 1) % opt.plot_every == 0:
                # Drop into the debugger when the sentinel file exists.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                vis.plot('content_loss', content_meter.value()[0])
                vis.plot('style_loss', style_meter.value()[0])
                # Denormalize input/output for display, since
                # utils.normalize_batch was applied above.
                vis.img('output',
                        (y.data.cpu()[0] * 0.225 + 0.45).clamp(min=0, max=1))
                vis.img('input', (x.data.cpu()[0] * 0.225 + 0.45).clamp(min=0,
                                                                        max=1))

        # Save the visdom env and a per-epoch model checkpoint.
        vis.save([opt.env])
        t.save(transformer.state_dict(), 'checkpoints/%s_style.pth' % epoch)
コード例 #3
0
                                         dim=0).contiguous()
                labels = torch.cat((visual_graph.y, points_graph.y),
                                   dim=0).contiguous()
                iloss, _, _, _, _, _ = global_loss(trip_loss, global_feats,
                                                   labels)
                batch_loss += iloss
            batch_loss = batch_loss / len(value[0])
            optimizer.zero_grad()
            batch_loss.backward()
            # clip_gradient(model, 10)
            optimizer.step()

            batch_time = time.time() - end_time
            avg.update(batch_time=batch_time, loss=batch_loss.item())
            if (step + 1) % args.disp_interval == 0:
                vis.plot('loss', avg.avg('loss'))
                log_str = '(Train) Epoch: [{0}][{1}/{2}]\t lr: {lr:.6f} \t {batch_time:s} \t {loss:s} \n'.format(
                    epoch,
                    step + 1,
                    len(train_dataloader),
                    lr=lr,
                    batch_time=avg.batch_time,
                    loss=avg.loss)
                vis.log(log_str)
        if args.trainval:
            # validation
            model.eval()
            valid_avg = AverageMeter()
            valid_disp_interval = int(args.disp_interval /
                                      len(train_dataloader) *
                                      len(valid_dataloader))