Example #1
    def __init__(self, hparams):
        super(GAN, self).__init__()
        self.hparams = hparams
        batch_size = self.hparams.batch_size
        # networks
        # mnist_shape = (1, 28, 28)
        self.generator_x2ct = Generator(batch_size)
        self.generator_ct2x = Vgg16()
        self.discriminator_ct = CT_Discriminator(batch_size)
        self.discriminator_x = X_Discriminator(batch_size)

        # cache for generated images
        self.generated_imgs = None
        self.last_imgs = None
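This __init__ comes from a GAN module whose sub-networks (Generator, Vgg16, CT_Discriminator, X_Discriminator) are defined elsewhere in the source repo. A minimal sketch of instantiating such a module, assuming hparams is an argparse-style namespace with a batch_size attribute (the value here is illustrative):

from argparse import Namespace

# hypothetical hyperparameters; a real script would parse these from the command line
hparams = Namespace(batch_size=4)
gan = GAN(hparams)  # builds both generators, both discriminators, and the image caches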
Example #2
def main():

    use_cuda = torch.cuda.is_available() and not args.no_cuda
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(device)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True

    rgb = False
    if args.mode == 'rgb':
        rgb = True

    if args.gray_scale:
        rgb = False

    if args.tracking_data_mod:
        args.input_size = 192

    # DATALOADER

    train_dataset = GesturesDataset(model=args.model, csv_path='csv_dataset', train=True, mode=args.mode, rgb=rgb,
                                    normalization_type=1,
                                    n_frames=args.n_frames, resize_dim=args.input_size,
                                    transform_train=args.train_transforms, tracking_data_mod=args.tracking_data_mod)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.n_workers)

    validation_dataset = GesturesDataset(model=args.model, csv_path='csv_dataset', train=False, mode=args.mode, rgb=rgb,
                                         normalization_type=1, n_frames=args.n_frames, resize_dim=args.input_size,
                                         tracking_data_mod=args.tracking_data_mod)
    validation_loader = DataLoader(validation_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_workers)

    # network parameters

    in_channels = args.n_frames if not rgb else args.n_frames * 3
    n_classes = args.n_classes

    if args.model == 'LeNet':
        model = LeNet(input_channels=in_channels, input_size=args.input_size, n_classes=n_classes).to(device)

    elif args.model == 'AlexNet':
        model = AlexNet(input_channels=in_channels, input_size=args.input_size, n_classes=n_classes).to(device)

    elif args.model == 'AlexNetBN':
        model = AlexNetBN(input_channels=in_channels, input_size=args.input_size, n_classes=n_classes).to(device)

    elif args.model == "Vgg16":
        model = Vgg16(input_channels=in_channels, input_size=args.input_size, n_classes=n_classes).to(device)

    elif args.model == "Vgg16P":
        model = models.vgg16(pretrained=args.pretrained)
        for params in model.parameters():
            params.requires_grad = False
        model.features._modules['0'] = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=(3, 3), stride=1, padding=1)
        model.classifier._modules['6'] = nn.Linear(4096, n_classes)
        # model.fc = torch.nn.Linear(model.fc.in_features, n_classes)
        model = model.to(device)

    elif args.model == "ResNet18P":
        model = models.resnet18(pretrained=args.pretrained)
        for params in model.parameters():
            params.requires_grad = False
        model._modules['conv1'] = nn.Conv2d(in_channels, 64, 7, stride=2, padding=3)
        model.fc = torch.nn.Linear(model.fc.in_features, n_classes)
        model = model.to(device)

    elif args.model == "ResNet34P":
        model = models.resnet34(pretrained=args.pretrained)
        for params in model.parameters():
            params.requires_grad = False
        model._modules['conv1'] = nn.Conv2d(in_channels, 64, 7, stride=2, padding=3)
        model.fc = torch.nn.Linear(model.fc.in_features, n_classes)
        model = model.to(device)

    elif args.model == "DenseNet121P":
        model = models.densenet121(pretrained=args.pretrained)
        for params in model.parameters():
            params.requires_grad = False
        model.features._modules['conv0'] = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=(7, 7),
                                                     stride=(2, 2), padding=(3, 3))
        model.classifier = nn.Linear(in_features=1024, out_features=n_classes, bias=True)
        model = model.to(device)

    elif args.model == "DenseNet161P":
        model = models.densenet161(pretrained=args.pretrained)
        # for params in model.parameters():
        #     params.requires_grad = False
        model.features._modules['conv0'] = nn.Conv2d(in_channels=in_channels, out_channels=96, kernel_size=(7, 7),
                                                     stride=(2, 2), padding=(3, 3))
        model.classifier = nn.Linear(in_features=2208, out_features=n_classes, bias=True)
        model = model.to(device)

    elif args.model == "DenseNet169P":
        model = models.densenet169(pretrained=args.pretrained)
        for params in model.parameters():
            params.requires_grad = False
        model.features._modules['conv0'] = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=(7, 7),
                                                     stride=(2, 2), padding=(3, 3))
        model.classifier = nn.Linear(in_features=1664, out_features=n_classes, bias=True)
        model = model.to(device)

    elif args.model == "DenseNet201P":
        model = models.densenet201(pretrained=args.pretrained)
        for params in model.parameters():
            params.requires_grad = False
        model.features._modules['conv0'] = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=(7, 7),
                                                     stride=(2, 2), padding=(3, 3))
        model.classifier = nn.Linear(in_features=1920, out_features=n_classes, bias=True)
        model = model.to(device)
    # RNN
    elif args.model == 'LSTM' or args.model == 'GRU':
        model = Rnn(rnn_type=args.model, input_size=args.input_size, hidden_size=args.hidden_size,
                    batch_size=args.batch_size,
                    num_classes=args.n_classes, num_layers=args.n_layers,
                    final_layer=args.final_layer).to(device)
    # C3D

    elif args.model == 'C3D':
        model = C3D(rgb=rgb, num_classes=args.n_classes)
        if args.pretrained:
            # load the pretrained C3D weights; strict=False because the layers below are replaced
            model.load_state_dict(torch.load('c3d_weights/c3d.pickle', map_location=device), strict=False)
            # for params in model.parameters():
            #     params.requires_grad = False

            model.conv1 = nn.Conv3d(1 if not rgb else 3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
            # fc6 is resized because the 30-frame input flattens to 16384 features
            model.fc6 = nn.Linear(16384, 4096)
            model.fc7 = nn.Linear(4096, 4096)
            model.fc8 = nn.Linear(4096, n_classes)

        model = model.to(device)


    # Conv-lstm
    elif args.model == 'Conv-lstm':
        model = ConvLSTM(input_size=(args.input_size, args.input_size),
                         input_dim=1 if not rgb else 3,
                         hidden_dim=[64, 64, 128],
                         kernel_size=(3, 3),
                         num_layers=args.n_layers,
                         batch_first=True,
                         ).to(device)
    elif args.model == 'DeepConvLstm':
        model = DeepConvLstm(input_channels_conv=1 if not rgb else 3, input_size_conv=args.input_size, n_classes=12,
                             n_frames=args.n_frames, batch_size=args.batch_size).to(device)

    elif args.model == 'ConvGRU':
        model = ConvGRU(input_size=40, hidden_sizes=[64, 128],
                        kernel_sizes=[3, 3], n_layers=2).to(device)

    else:
        raise NotImplementedError

    if args.opt == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
        # optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=args.momentum)

    elif args.opt == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    else:
        raise NotImplementedError

    loss_function = nn.CrossEntropyLoss().to(device)

    start_epoch = 0
    if args.resume:
        checkpoint = torch.load("/projects/fabio/weights/gesture_recog_weights/checkpoint{}.pth.tar".format(args.model))
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']

        print("Resuming state:\n-epoch: {}\n{}".format(start_epoch, model))

    # experiment name
    personal_name = "{}_{}_{}".format(args.model, args.mode, args.exp_name)
    info_experiment = "{}".format(personal_name)
    log_dir = "/projects/fabio/logs/gesture_recog_logs/exps"
    weight_dir = personal_name
    log_file = open("{}/{}.txt".format("/projects/fabio/logs/gesture_recog_logs/txt_logs", personal_name), 'w')
    log_file.write(personal_name + "\n\n")
    if personal_name:
        exp_name = (("exp_{}_{}".format(time.strftime("%c"), personal_name)).replace(" ", "_")).replace(":", "-")
    else:
        exp_name = (("exp_{}".format(time.strftime("%c"), personal_name)).replace(" ", "_")).replace(":", "-")
    writer = SummaryWriter("{}".format(os.path.join(log_dir, exp_name)))

    # add info experiment
    writer.add_text('Info experiment',
                    "model:{}"
                    "\n\npretrained:{}"
                    "\n\nbatch_size:{}"
                    "\n\nepochs:{}"
                    "\n\noptimizer:{}"
                    "\n\nlr:{}"
                    "\n\ndn_lr:{}"
                    "\n\nmomentum:{}"
                    "\n\nweight_decay:{}"
                    "\n\nn_frames:{}"
                    "\n\ninput_size:{}"
                    "\n\nhidden_size:{}"
                    "\n\ntracking_data_mode:{}"
                    "\n\nn_classes:{}"
                    "\n\nmode:{}"
                    "\n\nn_workers:{}"
                    "\n\nseed:{}"
                    "\n\ninfo:{}"
                    "".format(args.model, args.pretrained, args.batch_size, args.epochs, args.opt, args.lr, args.dn_lr, args.momentum,
                              args.weight_decay, args.n_frames, args.input_size, args.hidden_size, args.tracking_data_mod,
                              args.n_classes, args.mode, args.n_workers, args.seed, info_experiment))

    trainer = Trainer(model=model, loss_function=loss_function, optimizer=optimizer, train_loader=train_loader,
                      validation_loader=validation_loader,
                      batch_size=args.batch_size, initial_lr=args.lr,  device=device, writer=writer, personal_name=personal_name, log_file=log_file,
                      weight_dir=weight_dir, dynamic_lr=args.dn_lr)


    print("experiment: {}".format(personal_name))
    start = time.time()
    for ep in range(start_epoch, args.epochs):
        trainer.train(ep)
        trainer.val(ep)

    # display classes results
    classes = ['g0', 'g1', 'g2', 'g3', 'g4', 'g5', 'g6', 'g7', 'g8', 'g9', 'g10', 'g11']
    for i in range(args.n_classes):
        print('Accuracy of {} : {:.3f}%'.format(
            classes[i], 100 * trainer.class_correct[i] / trainer.class_total[i]))

    end = time.time()
    h, rem = divmod(end - start, 3600)
    m, s = divmod(rem, 60)
    print("\nelapsed time (ep.{}):{:0>2}:{:0>2}:{:05.2f}".format(args.epochs, int(h), int(m), s))


    # writing accuracy on file

    log_file.write("\n\n")
    for i in range(args.n_classes):
        log_file.write('Accuracy of {} : {:.3f}%\n'.format(
            classes[i], 100 * trainer.class_correct[i] / trainer.class_total[i]))
    log_file.close()
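Every pretrained branch above applies the same surgery: freeze the backbone, replace the first convolution so it accepts the stacked-frame channel count, and replace the classifier head for the target classes. A standalone sketch of that pattern using torchvision's resnet18 (adapt_pretrained is a hypothetical helper; in_channels and n_classes stand in for the script's arguments):

import torch.nn as nn
from torchvision import models

def adapt_pretrained(in_channels, n_classes):
    model = models.resnet18(pretrained=True)
    for p in model.parameters():  # freeze the pretrained backbone
        p.requires_grad = False
    # new first conv matches the stacked frames; trained from scratch
    model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3)
    # new head sized for the target classes; trained from scratch
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    return model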
Example #3
import cv2
import time
from models import Xception, Vgg16, Resnet50, InceptionV3, InceptionResNetV2, MobileNet


if __name__ == "__main__":
  model = Vgg16()
  cap = cv2.VideoCapture(0)
  while True:
    start = time.perf_counter()  # wall-clock timer; process_time would miss time spent waiting on I/O or the GPU

    ret, frame = cap.read()
    read_time = time.perf_counter()

    x = model.preprocess(frame)
    preprocess_time = time.perf_counter()

    model.predict(x)
    predict_time = time.perf_counter()

    print("read time: ", read_time - start)
    print("preprocess time: ", preprocess_time - read_time)
    print("predict time: ", predict_time - preprocess_time)

    frame_rate = 1 / (predict_time - start)  # rate over the full read/preprocess/predict cycle
    print("frame rate: ", frame_rate)
    if cv2.waitKey(1) & 0xFF == ord('q'):
      break
  cap.release()
  cv2.destroyAllWindows()
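The per-stage timing above can be factored into a small helper; a sketch using a context manager (timed is a hypothetical name, not part of this repo):

import time
from contextlib import contextmanager

@contextmanager
def timed(label):
    # wall-clock timing for one stage of the capture loop
    t0 = time.perf_counter()
    yield
    print(label, time.perf_counter() - t0)

# usage: with timed("read time:"): ret, frame = cap.read()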
Example #4
def train(**kwargs):
    # step1:config
    opt.parse(**kwargs)
    vis = Visualizer(opt.env)
    device = t.device('cuda') if opt.use_gpu else t.device('cpu')
    
    # step2:data
    # dataloader, style_img
    # Unlike earlier chapters, where images were normalized, here a Lambda multiplies by 255; this choice deserves a proper justification
    # Two kinds of images are involved: content images (many, used for training) and a single style image (used by the loss)
    
    transforms = T.Compose([
        T.Resize(opt.image_size),
        T.CenterCrop(opt.image_size),
        T.ToTensor(),
        T.Lambda(lambda x: x*255)    
    ])
    # Images are loaded the same way as in chapter 7: via ImageFolder rather than a custom dataset
    dataset = tv.datasets.ImageFolder(opt.data_root,transform=transforms)
    dataloader = DataLoader(dataset,batch_size=opt.batch_size,shuffle=True,num_workers=opt.num_workers,drop_last=True)
    
    style_img = get_style_data(opt.style_path) # 1*c*H*W
    style_img = style_img.to(device)
    vis.img('style_image', (style_img.data[0]*0.225+0.45).clamp(min=0, max=1))  # denormalize for display; personally this seems unnecessary, worth experimenting
    
    # step3: the model: TransformerNet plus the loss network Vgg16
    # The model has two parts: TransformerNet, which transforms the content image, and Vgg16, which only evaluates the loss.
    # Vgg16's parameters therefore take no part in backpropagation; only TransformerNet's do.
    # We train and save TransformerNet alone; Vgg16's weights were loaded when the network was built.
    # Vgg16 runs under model.eval(), which changes the behavior of layers such as dropout and batch norm.
    # When should model.eval() be called? Previously only in val/test; what justifies it for Vgg16 here?
    # The author loads the model with a simple map_location trick, which is lighter-weight;
    # the author increasingly favors these convenient shortcuts.
    # Regarding CUDA: the model is moved to the device up front, while the data is moved only during training; note the difference.
    # In chapter 7 the author separated networks differently: fake_img = netg(noises).detach() turns a non-leaf node into a leaf that needs no gradient.
    # Chapter 4 still needs a re-read.
    
    transformer_net = TransformerNet()
    
    if opt.model_path:
        transformer_net.load_state_dict(t.load(opt.model_path, map_location=lambda _s, _: _s))
    transformer_net.to(device)
    

    
    # step3: criterion and optimizer
    optimizer = t.optim.Adam(transformer_net.parameters(), opt.lr)
    # the loss is computed through vgg16: it combines Gram matrices with mean squared error, so we also need gram_matrix and an MSE criterion
    vgg16 = Vgg16().eval()  # to be verified
    vgg16.to(device)
    # vgg's parameters need no gradients, but backprop still has to flow through it
    # revisit the difference between detach and requires_grad later
    for param in vgg16.parameters():
        param.requires_grad = False
    criterion = t.nn.MSELoss(reduction='mean')  # 'reduce'/'size_average' are deprecated
    
    
    # step4: meters for loss statistics
    style_meter = meter.AverageValueMeter()
    content_meter = meter.AverageValueMeter()
    total_meter = meter.AverageValueMeter()
    
    # step5.2 (prep): compute the Gram matrices of the style image
    # gram_style: list over [relu1_2, relu2_2, relu3_3, relu4_3]; each entry is a b*c*c tensor
    with t.no_grad():
        features = vgg16(style_img)
        gram_style = [gram_matrix(feature) for feature in features]
    # loss network: Vgg16
    # step5: train
    for epoch in range(opt.epoches):
        style_meter.reset()
        content_meter.reset()
        total_meter.reset()
        
        # step5.1: train
        for ii,(data,_) in tqdm(enumerate(dataloader)):
            optimizer.zero_grad()
            # Unlike earlier code, the author no longer wraps tensors in Variable()
            # since PyTorch 0.4, Tensor and Variable are merged: a created tensor is a variable
            # https://mp.weixin.qq.com/s?__biz=MzI0ODcxODk5OA==&mid=2247494701&idx=2&sn=ea8411d66038f172a2f553770adccbec&chksm=e99edfd4dee956c23c47c7bb97a31ee816eb3a0404466c1a57c12948d807c975053e38b18097&scene=21#wechat_redirect
            data = data.to(device)
            y = transformer_net(data)
            # vgg requires normalized input images
            data = normalize_batch(data)
            y = normalize_batch(y)

           
            feature_data = vgg16(data)
            feature_y = vgg16(y) 
            # open question: what shape do these features have?
            
            # step5.2: loss: content loss and style loss
            # content_loss
            # Unlike the book, which uses relu3_3, the code uses relu2_2
            # https://blog.csdn.net/zhangxb35/article/details/72464152?utm_source=itdadao&utm_medium=referral
            # MSE here is per-element: the sum over N*b*h*w elements divided by N*b*h*w
            # SGD itself averages the loss over the batch before backpropagating
            content_loss = opt.content_weight*criterion(feature_y.relu2_2,feature_data.relu2_2)
            # style loss, over relu1_2, relu2_2, relu3_3, relu4_3
            # this needs the Gram matrix of every generated image
            
            style_loss = 0
            # a tensor can be iterated directly (`for i in tensor:`), which unpacks only the outermost dimension
            # ft_y: b*c*h*w, gm_s: 1*c*c
            for ft_y, gm_s in zip(feature_y, gram_style):
                gram_y = gram_matrix(ft_y)
                style_loss += criterion(gram_y, gm_s.expand_as(gram_y))
            style_loss *= opt.style_weight
            
            total_loss = content_loss + style_loss
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            # import ipdb
            # ipdb.set_trace()
            # to read values out of a tensor: tensor.item() / tensor.tolist()
            content_meter.add(content_loss.item())
            style_meter.add(style_loss.item())
            total_meter.add(total_loss.item())
            
            # step5.3: visualize
            if (ii+1)%opt.print_freq == 0 and opt.vis:
                # why is debugging always done in this form?
                if os.path.exists(opt.debug_file):
                    import ipdb
                    ipdb.set_trace()
                vis.plot('content_loss',content_meter.value()[0])
                vis.plot('style_loss',style_meter.value()[0])
                vis.plot('total_loss',total_meter.value()[0])
                # data and y have been normalized to roughly [-2, 2]; map them back to [0, 1] for display
                vis.img('input',(data.data*0.225+0.45)[0].clamp(min=0,max=1))
                vis.img('output',(y.data*0.225+0.45)[0].clamp(min=0,max=1))
            
        # step 5.4 save and validate and visualize
        if (epoch+1) % opt.save_every == 0:
            t.save(transformer_net.state_dict(), 'checkpoints/%s_style.pth' % epoch)
            # several ways to save images; chapter 7 used:
            # tv.utils.save_image(fix_fake_imgs, '%s/%s.png' % (opt.img_save_path, epoch), normalize=True, range=(-1, 1))
            # vis.save was nowhere to be found in the docs, good grief
            vis.save([opt.env])
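The helpers gram_matrix and normalize_batch called in this example are defined elsewhere in the repo. Minimal PyTorch sketches consistent with how they are used here, assuming standard ImageNet statistics (the display code above denormalizes with the rounded values 0.225/0.45):

import torch

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def gram_matrix(x):
    # x: b*c*h*w -> b*c*c, normalized by the element count of each channel map
    b, c, h, w = x.size()
    ft = x.view(b, c, h * w)
    return ft.bmm(ft.transpose(1, 2)) / (c * h * w)

def normalize_batch(batch):
    # the batch arrives in [0, 255] (see the *255 Lambda above); rescale, then standardize
    batch = batch / 255.0
    return (batch - IMAGENET_MEAN.to(batch.device)) / IMAGENET_STD.to(batch.device)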
Example #5
def main(arguments):
    x_data = dict()
    y_data = dict()
    tensors = dict()
    color_data = dict()
    flags = dict()
    flags['num_epochs'] = arguments.me
    flags['batch_size'] = arguments.batchsize
    flags['hold_prob'] = arguments.holdprobabilty
    flags['numberofimages'] = arguments.numberOfimages
    flags['cd'] = arguments.choosedevice
    flags['dataset_folder'] = arguments.foldername
    flags['model'] = arguments.model
    flags['learn_rate'] = 1e-4
    flags['vgg16_weights'] = 'model_weights/vgg16_weights.npz'
    flags['inception_v3_weights'] = 'model_weights/inception_v3.ckpt'
    flags['chatbot_tensor'] = []

    device_name = ['/CPU:0', '/GPU:0']
    if device_name[flags['cd']] == '/CPU:0':
        print('Using CPU')
    else:
        print('Using GPU')

    # extracting images names and get mean per color
    images_names, labels = get_images_names(number_of_images=flags['numberofimages'],
                                            orig_data=flags['dataset_folder'])
    color_data['r_mean'], color_data['g_mean'], color_data['b_mean'] = clr_mean(images_names)

    # split data to test and train_dev
    x_train_dev, x_data['x_test'], y_train_dev, y_data['y_test'] = train_test_split(images_names,
                                                                                    labels,
                                                                                    test_size=0.1,
                                                                                    random_state=17)
    assert len(set(y_data['y_test'])) == len(set(y_train_dev))

    # split train_dev to train and dev
    x_data['x_train'], x_data['x_dev'], y_data['y_train'], y_data['y_dev'] = train_test_split(x_train_dev,
                                                                                              y_train_dev,
                                                                                              test_size=0.1,
                                                                                              random_state=17)
    assert len(set(y_data['y_dev'])) == len(set(y_data['y_train'])) == len(set(y_data['y_test']))

    # one hot encode the labels
    flags['num_classes'] = len(set(y_data['y_test']))
    # print(np.array(y_train).shape, np.array(y_dev).shape, np.array(y_test).shape)
    y_data['y_train'] = np_utils.to_categorical(y_data['y_train'], flags['num_classes'])
    y_data['y_dev'] = np_utils.to_categorical(y_data['y_dev'], flags['num_classes'])
    y_data['y_test'] = np_utils.to_categorical(y_data['y_test'], flags['num_classes'])

    flags['folder'] = '{:d} classes results batchsize {:d} holdprob {:.2f}/'.format(flags['num_classes'],
                                                                                    flags['batch_size'],
                                                                                    flags['hold_prob'])
    # create folders to save results and model
    if not os.path.exists(flags['folder']):
        os.makedirs(flags['folder'])
    if not os.path.exists(flags['folder'] + 'model/'):
        os.makedirs(flags['folder'] + 'model/')

    '''
    initialize model
    '''

    # better allocate memory during training in GPU
    config = tf.ConfigProto()
    config.gpu_options.allocator_type = 'BFC'

    # create input and label tensors placeholder
    with tf.device(device_name[flags['cd']]):
        tensors['hold_prob'] = tf.placeholder_with_default(1.0, shape=(), name='hold_prob')
        if flags['model'] == 'vgg16':
            tensors['input_layer'] = tf.placeholder(tf.float32, [None, 224, 224, 3], 'input_layer')
        else:
            tensors['input_layer'] = tf.placeholder(tf.float32, [None, 299, 299, 3], 'input_layer')
        tensors['labels_tensor'] = tf.placeholder(tf.float32, [None, flags['num_classes']])
        if flags['chatbot_tensor']:
            tensors['chatbot_tensor'] = tf.placeholder(tf.float32, [None, 10], 'chatbot_tensor')
        else:
            tensors['chatbot_tensor'] = []

    # start tensorflow session
    with tf.Session(config=config) as sess:
        # create the vgg16 model
        tic = time.perf_counter()  # time.clock() was removed in Python 3.8
        if flags['model'] == 'vgg16':
            model = Vgg16(tensors['input_layer'], tensors['chatbot_tensor'], flags['vgg16_weights'], sess,
                          hold_prob=tensors['hold_prob'],
                          num_classes=flags['num_classes'])
        else:
            model = InceptionV3(tensors['input_layer'], flags['inception_v3_weights'], sess,
                                flags['hold_prob'],
                                num_classes=flags['num_classes'])
        toc = time.perf_counter()
        print('loading model time: ', toc-tic)

        writer = tf.summary.FileWriter('tensorboard')
        writer.add_graph(sess.graph)
        print('start tensorboard')

        # train, dev, and test
        start_training(x_data=x_data, y_data=y_data, flags=flags, color_data=color_data, session=sess,
                       tensors=tensors, last_fc=model.last_layer)
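np_utils.to_categorical above is the Keras helper that one-hot encodes integer labels; a self-contained NumPy equivalent for reference:

import numpy as np

def to_categorical(y, num_classes):
    # one-hot encode integer class labels into an (N, num_classes) float matrix
    y = np.asarray(y, dtype=int)
    out = np.zeros((len(y), num_classes), dtype=np.float32)
    out[np.arange(len(y)), y] = 1.0
    return out

print(to_categorical([0, 2, 1], 3))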
Example #6
def train():
    if mx.context.num_gpus() > 0:
        ctx = mx.gpu()
    else:
        raise RuntimeError('There is no GPU device!')

    # loading configs
    args = Options().parse()
    cfg = Configs(args.config_path)
    # set logging level
    logging.basicConfig(level=logging.INFO)
    # set random seed
    np.random.seed(cfg.seed)

    # build dataset and loader
    content_dataset = ImageFolder(cfg.content_dataset, cfg.img_size, ctx=ctx)
    style_dataset = StyleLoader(cfg.style_dataset, cfg.style_size, ctx=ctx)
    content_loader = gluon.data.DataLoader(content_dataset, batch_size=cfg.batch_size, \
                                            last_batch='discard')

    vgg = Vgg16()
    vgg._init_weights(fixed=True, pretrain_path=cfg.vgg_check_point, ctx=ctx)

    style_model = Net(ngf=cfg.ngf)
    if cfg.resume is not None:
        print("Resuming from {} ...".format(cfg.resume))
        style_model.collect_params().load(cfg.resume, ctx=ctx)
    else:
        style_model.initialize(mx.initializer.MSRAPrelu(), ctx=ctx)
    print("Style model:")
    print(style_model)

    # build trainer
    lr_sche = mx.lr_scheduler.FactorScheduler(
        step=170000,
        factor=0.1,
        base_lr=cfg.base_lr
        #warmup_begin_lr=cfg.base_lr/3.0,
        #warmup_steps=300,
    )
    opt = mx.optimizer.Optimizer.create_optimizer('adam', lr_scheduler=lr_sche)
    trainer = gluon.Trainer(style_model.collect_params(), optimizer=opt)

    loss_fn = gluon.loss.L2Loss()

    logging.info("Start training with total {} epoch".format(cfg.total_epoch))
    iteration = 0
    total_time = 0.0
    num_batch = len(content_loader) * cfg.total_epoch
    for epoch in range(cfg.total_epoch):
        sum_content_loss = 0.0
        sum_style_loss = 0.0
        for batch_id, content_imgs in enumerate(content_loader):
            iteration += 1
            s = time.time()
            style_image = style_dataset.get(batch_id)

            style_vgg_input = subtract_imagenet_mean_preprocess_batch(
                style_image.copy())
            style_image = preprocess_batch(style_image)
            style_features = vgg(style_vgg_input)
            style_features = [
                style_model.gram.gram_matrix(mx.nd, f) for f in style_features
            ]

            content_vgg_input = subtract_imagenet_mean_preprocess_batch(
                content_imgs.copy())
            content_features = vgg(content_vgg_input)[1]

            with autograd.record():
                y = style_model(content_imgs, style_image)
                y = subtract_imagenet_mean_batch(y)
                y_features = vgg(y)

                content_loss = 2 * cfg.content_weight * loss_fn(
                    y_features[1], content_features)
                style_loss = 0.0
                for m in range(len(y_features)):
                    gram_y = style_model.gram.gram_matrix(mx.nd, y_features[m])
                    _, C, _ = style_features[m].shape
                    gram_s = mx.nd.expand_dims(style_features[m],
                                               0).broadcast_to((
                                                   gram_y.shape[0],
                                                   1,
                                                   C,
                                                   C,
                                               ))
                    style_loss = style_loss + 2 * cfg.style_weight * loss_fn(
                        gram_y, gram_s)
                total_loss = content_loss + style_loss
                total_loss.backward()

            trainer.step(cfg.batch_size)
            mx.nd.waitall()
            e = time.time()
            total_time += e - s
            sum_content_loss += content_loss[0]
            sum_style_loss += style_loss[0]
            if iteration % cfg.log_interval == 0:
                itera_sec = total_time / iteration
                eta_str = str(
                    datetime.timedelta(seconds=int((num_batch - iteration) *
                                                   itera_sec)))
                mesg = "{} Epoch [{}]:\t[{}/{}]\tTime:{:.2f}s\tETA:{}\tlr:{:.4f}\tcontent: {:.3f}\tstyle: {:.3f}\ttotal: {:.3f}".format(
                    time.strftime("%H:%M:%S",
                                  time.localtime()), epoch + 1, batch_id + 1,
                    len(content_loader), itera_sec, eta_str,
                    trainer.optimizer.learning_rate,
                    sum_content_loss.asnumpy()[0] / (batch_id + 1),
                    sum_style_loss.asnumpy()[0] / (batch_id + 1),
                    (sum_content_loss + sum_style_loss).asnumpy()[0] /
                    (batch_id + 1))
                logging.info(mesg)
                ctx.empty_cache()
        save_model_filename = "Epoch_" + str(epoch + 1) +  "_" + str(time.ctime()).replace(' ', '_') + \
                "_" + str(cfg.content_weight) + "_" + str(cfg.style_weight) + ".params"
        if not os.path.isdir(cfg.save_model_dir):
            os.mkdir(cfg.save_model_dir)
        save_model_path = os.path.join(cfg.save_model_dir, save_model_filename)
        logging.info("Saving parameters to {}".format(save_model_path))
        style_model.collect_params().save(save_model_path)
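The expand_dims/broadcast_to step above replicates the single style image's Gram matrix across the batch of generated Grams. The same step in isolation, with illustrative shapes:

import mxnet as mx

gram_s = mx.nd.random.uniform(shape=(1, 64, 64))  # Gram of the style image: 1*C*C
batched = mx.nd.expand_dims(gram_s, 0)            # 1*1*C*C
batched = batched.broadcast_to((8, 1, 64, 64))    # replicate across a batch of 8
print(batched.shape)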
Example #7
    def __init__(self, batch_size):
        super(Generator, self).__init__()
        self.vgg16 = Vgg16()
Example #8
def iter_train(**kwargs):
    config.parse(kwargs)

    # ============================================ Visualization =============================================
    # vis = Visualizer(port=2333, env=config.env)
    # vis.log('Use config:')
    # for k, v in config.__class__.__dict__.items():
    #     if not k.startswith('__'):
    #         vis.log(f"{k}: {getattr(config, k)}")

    # ============================================= Prepare Data =============================================
    train_data = VB_Dataset(config.train_paths,
                            phase='train',
                            num_classes=config.num_classes,
                            useRGB=config.useRGB,
                            usetrans=config.usetrans,
                            padding=config.padding,
                            balance=config.data_balance)
    val_data = VB_Dataset(config.test_paths,
                          phase='val',
                          num_classes=config.num_classes,
                          useRGB=config.useRGB,
                          usetrans=config.usetrans,
                          padding=config.padding,
                          balance=config.data_balance)
    train_dist, val_dist = train_data.dist(), val_data.dist()
    train_data_scale, val_data_scale = train_data.scale, val_data.scale
    print('Training Images:', len(train_data), 'Validation Images:', len(val_data))
    print('Train Data Distribution:', train_dist, 'Val Data Distribution:',
          val_dist)

    train_dataloader = DataLoader(train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers)
    val_dataloader = DataLoader(val_data,
                                batch_size=config.batch_size,
                                shuffle=False,
                                num_workers=config.num_workers)

    # ============================================= Prepare Model ============================================
    # model = ResNet18(num_classes=config.num_classes)
    # model = ResNet34(num_classes=config.num_classes)
    # model = ResNet50(num_classes=config.num_classes)
    model = Vgg16(num_classes=config.num_classes)
    # model = AlexNet(num_classes=config.num_classes)
    # model = densenet_collapse(num_classes=config.num_classes)
    # model = ShallowVgg(num_classes=config.num_classes)
    # model = CustomedNet(num_classes=config.num_classes)
    # model = DualNet(num_classes=config.num_classes)
    # model = SkipResNet18(num_classes=config.num_classes)
    # model = DensResNet18(num_classes=config.num_classes)
    # model = GuideResNet18(num_classes=config.num_classes)
    # print(model)

    if config.load_model_path:
        model.load(config.load_model_path)
    if config.use_gpu:
        model.cuda()
    if config.parallel:
        model = torch.nn.DataParallel(model,
                                      device_ids=list(range(
                                          config.num_of_gpu)))

    # =========================================== Criterion and Optimizer =====================================
    # weight = torch.Tensor([1/dist['0'], 1/dist['1'], 1/dist['2'], 1/dist['3']])
    # weight = torch.Tensor([1/dist['0'], 1/dist['1']])
    # weight = torch.Tensor([dist['1'], dist['0']])
    # weight = torch.Tensor([1, 10])
    # vis.log(f'loss weight: {weight}')
    # print('loss weight:', weight)
    # weight = weight.cuda()

    criterion = torch.nn.CrossEntropyLoss()
    # criterion = torch.nn.CrossEntropyLoss(weight=weight)
    # criterion = LabelSmoothing(size=config.num_classes, smoothing=0.2)
    # criterion = FocalLoss(gamma=4, alpha=None)

    lr = config.lr
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=config.weight_decay)

    # ================================================== Metrics ===============================================
    log_softmax = functional.log_softmax
    loss_meter = meter.AverageValueMeter()

    # ====================================== Saving and Recording Configuration =================================
    previous_AUC = 0
    previous_mAP = 0
    save_iter = 1  # records the iteration at which the model performed best on the validation set
    if config.parallel:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.module.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.module.model_name + '_best_model.pth'
    else:
        save_model_dir = config.save_model_dir if config.save_model_dir else model.model_name
        save_model_name = config.save_model_name if config.save_model_name else model.model_name + '_best_model.pth'
    if config.num_classes == 2:  # binary classification
        process_record = {
            'loss': [],  # curves recorded during training, for plotting later
            'train_avg': [],
            'train_sp': [],
            'train_se': [],
            'val_avg': [],
            'val_sp': [],
            'val_se': [],
            'train_AUC': [],
            'val_AUC': []
        }
    elif config.num_classes == 3:  # 3-class classification
        process_record = {
            'loss': [],  # curves recorded during training, for plotting later
            'train_sp0': [],
            'train_se0': [],
            'train_sp1': [],
            'train_se1': [],
            'train_sp2': [],
            'train_se2': [],
            'val_sp0': [],
            'val_se0': [],
            'val_sp1': [],
            'val_se1': [],
            'val_sp2': [],
            'val_se2': [],
            'train_mAUC': [],
            'val_mAUC': [],
            'train_mAP': [],
            'val_mAP': []
        }
    else:
        raise ValueError

    # ================================================== Training ===============================================
    iteration = 0
    # ****************************************** train ****************************************
    train_iter = iter(train_dataloader)
    model.train()
    while iteration < config.max_iter:
        try:
            image, label, image_path = next(train_iter)
        except StopIteration:
            # the loader is exhausted; rebuild it (iteration-based training)
            train_iter = iter(train_dataloader)
            image, label, image_path = next(train_iter)

        iteration += 1

        # ------------------------------------ prepare input ------------------------------------
        if config.use_gpu:
            image = image.cuda()
            label = label.cuda()

        # ---------------------------------- go through the model --------------------------------
        score = model(image)

        # ----------------------------------- backpropagate -------------------------------------
        optimizer.zero_grad()
        loss = criterion(score, label)
        # loss = criterion(log_softmax(score, dim=1), label)  # LabelSmoothing
        loss.backward()
        optimizer.step()

        # ------------------------------------ record loss ------------------------------------
        loss_meter.add(loss.item())

        if iteration % config.print_freq == 0:
            tqdm.write(
                f"iter: [{iteration}/{config.max_iter}] {config.save_model_name[:-4]} =================================="
            )

            # *************************************** validate ***************************************
            if config.num_classes == 2:  # binary classification
                model.eval()
                train_cm, train_AUC, train_sp, train_se, train_T, train_accuracy = val_2class(
                    model, train_dataloader, train_dist)
                val_cm, val_AUC, val_sp, val_se, val_T, val_accuracy = val_2class(
                    model, val_dataloader, val_dist)
                # vis.plot('loss', loss_meter.value()[0])

                model.train()

                # ------------------------------------ save model ------------------------------------
                if val_AUC > previous_AUC:  # save the model whenever the validation AUC improves
                    if config.parallel:
                        if not os.path.exists(
                                os.path.join('checkpoints', save_model_dir,
                                             save_model_name[:-4])):
                            os.makedirs(
                                os.path.join('checkpoints', save_model_dir,
                                             save_model_name[:-4]))
                        model.module.save(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name[:-4],
                                         save_model_name))
                    else:
                        if not os.path.exists(
                                os.path.join('checkpoints', save_model_dir,
                                             save_model_name[:-4])):
                            os.makedirs(
                                os.path.join('checkpoints', save_model_dir,
                                             save_model_name[:-4]))
                        model.save(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name[:-4],
                                         save_model_name))
                    previous_AUC = val_AUC
                    save_iter = iteration

                # ---------------------------------- record and print ---------------------------------
                process_record['loss'].append(loss_meter.value()[0])
                process_record['train_avg'].append((train_sp + train_se) / 2)
                process_record['train_sp'].append(train_sp)
                process_record['train_se'].append(train_se)
                process_record['train_AUC'].append(train_AUC)
                process_record['val_avg'].append((val_sp + val_se) / 2)
                process_record['val_sp'].append(val_sp)
                process_record['val_se'].append(val_se)
                process_record['val_AUC'].append(val_AUC)

                # vis.plot_many({'loss': loss_meter.value()[0],
                #                'train_avg': (train_sp + train_se) / 2, 'train_sp': train_sp, 'train_se': train_se,
                #                'val_avg': (val_sp + val_se) / 2, 'val_sp': val_sp, 'val_se': val_se,
                #                'train_AUC': train_AUC, 'val_AUC': val_AUC})
                # vis.log(f"iter: [{iteration}/{config.max_iter}] =========================================")
                # vis.log(f"lr: {optimizer.param_groups[0]['lr']}, loss: {round(loss_meter.value()[0], 5)}")
                # vis.log(f"train_avg: {round((train_sp + train_se) / 2, 4)}, train_sp: {round(train_sp, 4)}, train_se: {round(train_se, 4)}")
                # vis.log(f"val_avg: {round((val_sp + val_se) / 2, 4)}, val_sp: {round(val_sp, 4)}, val_se: {round(val_se, 4)}")
                # vis.log(f'train_AUC: {train_AUC}')
                # vis.log(f'val_AUC: {val_AUC}')
                # vis.log(f'train_cm: {train_cm}')
                # vis.log(f'val_cm: {val_cm}')
                print("lr:", optimizer.param_groups[0]['lr'], "loss:",
                      round(loss_meter.value()[0], 5))
                print('train_avg:', round((train_sp + train_se) / 2, 4),
                      'train_sp:', round(train_sp, 4), 'train_se:',
                      round(train_se, 4))
                print('val_avg:', round((val_sp + val_se) / 2, 4), 'val_sp:',
                      round(val_sp, 4), 'val_se:', round(val_se, 4))
                print('train_AUC:', train_AUC, 'val_AUC:', val_AUC)
                print('train_cm:')
                print(train_cm)
                print('val_cm:')
                print(val_cm)

            elif config.num_classes == 3:  # 3-class classification
                model.eval()
                train_cm, train_mAP, train_sp, train_se, train_mAUC, train_accuracy = val_3class(
                    model, train_dataloader, train_data_scale)
                val_cm, val_mAP, val_sp, val_se, val_mAUC, val_accuracy = val_3class(
                    model, val_dataloader, val_data_scale)
                model.train()

                # ------------------------------------ save model ------------------------------------
                if val_mAP > previous_mAP:  # save the model whenever the validation mAP improves
                    if config.parallel:
                        if not os.path.exists(
                                os.path.join('checkpoints', save_model_dir,
                                             save_model_name[:-4])):
                            os.makedirs(
                                os.path.join('checkpoints', save_model_dir,
                                             save_model_name[:-4]))
                        model.module.save(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name[:-4],
                                         save_model_name))
                    else:
                        if not os.path.exists(
                                os.path.join('checkpoints', save_model_dir,
                                             save_model_name[:-4])):
                            os.makedirs(
                                os.path.join('checkpoints', save_model_dir,
                                             save_model_name[:-4]))
                        model.save(
                            os.path.join('checkpoints', save_model_dir,
                                         save_model_name[:-4],
                                         save_model_name))
                    previous_mAP = val_mAP
                    save_iter = iteration

                # ---------------------------------- record and print ---------------------------------
                process_record['loss'].append(loss_meter.value()[0])
                process_record['train_sp0'].append(train_sp[0])
                process_record['train_se0'].append(train_se[0])
                process_record['train_sp1'].append(train_sp[1])
                process_record['train_se1'].append(train_se[1])
                process_record['train_sp2'].append(train_sp[2])
                process_record['train_se2'].append(train_se[2])
                process_record['train_mAUC'].append(float(train_mAUC))
                process_record['train_mAP'].append(float(train_mAP))
                process_record['val_sp0'].append(val_sp[0])
                process_record['val_se0'].append(val_se[0])
                process_record['val_sp1'].append(val_sp[1])
                process_record['val_se1'].append(val_se[1])
                process_record['val_sp2'].append(val_sp[2])
                process_record['val_se2'].append(val_se[2])
                process_record['val_mAUC'].append(float(val_mAUC))
                process_record['val_mAP'].append(float(val_mAP))

                # vis.plot_many({'loss': loss_meter.value()[0],
                #                'train_sp0': train_se[0], 'train_sp1': train_se[1], 'train_sp2': train_se[2],
                #                'train_se0': train_se[0], 'train_se1': train_se[1], 'train_se2': train_se[2],
                #                'val_sp0': val_se[0], 'val_sp1': val_se[1], 'val_sp2': val_se[2],
                #                'val_se0': val_se[0], 'val_se1': val_se[1], 'val_se2': val_se[2],
                #                'train_mAP': train_mAP, 'val_mAP': val_mAP})
                # vis.log(f"iter: [{iteration}/{config.max_iter}] =========================================")
                # vis.log(f"lr: {optimizer.param_groups[0]['lr']}, loss: {round(loss_meter.value()[0], 5)}")
                # vis.log(f"train_sp0: {round(train_sp[0], 4)}, train_sp1: {round(train_sp[1], 4)}, train_sp2: {round(train_sp[2], 4)}")
                # vis.log(f"train_se0: {round(train_se[0], 4)}, train_se1: {round(train_se[1], 4)}, train_se2: {round(train_se[2], 4)}")
                # vis.log(f"val_sp0: {round(val_sp[0], 4)}, val_sp1: {round(val_sp[1], 4)}, val_sp2: {round(val_sp[2], 4)}")
                # vis.log(f"val_se0: {round(val_se[0], 4)}, val_se1: {round(val_se[1], 4)}, val_se2: {round(val_se[2], 4)}")
                # vis.log(f"train_mAP: {train_mAP}, val_mAP: {val_mAP}")
                # vis.log(f'train_cm: {train_cm}')
                # vis.log(f'val_cm: {val_cm}')
                print("lr:", optimizer.param_groups[0]['lr'], "loss:",
                      round(loss_meter.value()[0], 5))
                print('train_sp0:', round(train_sp[0], 4), 'train_sp1:',
                      round(train_sp[1], 4), 'train_sp2:',
                      round(train_sp[2], 4))
                print('train_se0:', round(train_se[0], 4), 'train_se1:',
                      round(train_se[1], 4), 'train_se2:',
                      round(train_se[2], 4))
                print('val_sp0:', round(val_sp[0], 4), 'val_sp1:',
                      round(val_sp[1], 4), 'val_sp2:', round(val_sp[2], 4))
                print('val_se0:', round(val_se[0], 4), 'val_se1:',
                      round(val_se[1], 4), 'val_se2:', round(val_se[2], 4))
                print('mSP:', round(sum(val_sp) / 3, 5), 'mSE:',
                      round(sum(val_se) / 3, 5))
                print('train_mAUC:', train_mAUC, 'val_mAUC:', val_mAUC)
                print('train_mAP:', train_mAP, 'val_mAP:', val_mAP)
                print('train_cm:')
                print(train_cm)
                print('val_cm:')
                print(val_cm)
                print('Best mAP:', previous_mAP)

            loss_meter.reset()

        # ------------------------------------ save record ------------------------------------
        if os.path.exists(
                os.path.join('checkpoints', save_model_dir,
                             save_model_name.split('.')[0])):
            write_json(file=os.path.join('checkpoints', save_model_dir,
                                         save_model_name.split('.')[0],
                                         'process_record.json'),
                       content=process_record)

    # vis.log(f"Best Iter: {save_iter}")
    print("Best Iter:", save_iter)
Example #9

if __name__ == '__main__':
    """ python grad_cam.py <path_to_image>
    1. Loads an image with opencv.
    2. Preprocesses it for VGG16 and converts to a pytorch variable.
    3. Makes a forward pass to find the category index with the highest score,
    and computes intermediate activations.
    Makes the visualization. """

    args = get_args()

    # Can work with any model, but it assumes that the model has a
    # feature method, and a classifier method,
    # as in the VGG models in torchvision.
    model = Vgg16(num_classes=2)
    # model = ResNet18(num_classes=2)
    # model = Customed_ShallowNet(num_classes=2)
    # model = CAM_CNN(num_classes=2)
    model.load(args.model_path)

    grad_cam = GradCam(model=model, target_layer_names=["30"], use_cuda=args.use_cuda)
    # grad_cam = GradCam(model=model, target_layer_names=["23"], use_cuda=args.use_cuda)
    # grad_cam = GradCam(model=model, target_layer_names=["16"], use_cuda=args.use_cuda)

    if args.image_path == 'csv':  # a csv file listing several images can be passed directly
        root = '/DB/rhome/bllai/PyTorchProjects/Vertebrae_Alignment_Torch1.0'
        test_paths = [os.path.join(root, 'dataset/test_D4F1.csv')]
        test_data = SlideWindowDataset(test_paths, phase='test', useRGB=True, usetrans=True, balance=False)
        test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=4)
Example #10
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    data_train = load_data(args)
    iterator = data_train

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(weights=args.vgg16, requires_grad=False).to(device)
    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])

    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        count = 0
        if args.noise_count:
            noiseimg_n = np.zeros((3, args.image_size, args.image_size),
                                  dtype=np.float32)
            # Preparing noise image.
            for n_c in range(args.noise_count):
                x_n = random.randrange(args.image_size)
                y_n = random.randrange(args.image_size)
                noiseimg_n[0][x_n][y_n] += random.randrange(-args.noise, args.noise)
                noiseimg_n[1][x_n][y_n] += random.randrange(-args.noise, args.noise)
                noiseimg_n[2][x_n][y_n] += random.randrange(-args.noise, args.noise)
            # convert once, after all noise pixels have been placed
            noiseimg = torch.from_numpy(noiseimg_n).to(device)
        for batch_id, sample in enumerate(iterator):
            x = sample['image']
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            if args.noise_count:
                # Adding the noise image to the source image.
                noisy_x = x + noiseimg
                noisy_y = transformer(noisy_x)
                noisy_y = utils.normalize_batch(noisy_y)

            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            features_y = vgg(y)
            features_x = vgg(x)

            L_feat = args.lambda_feat * mse_loss(features_y.relu2_2,
                                                 features_x.relu2_2)

            L_style = 0.
            for ft_y, gm_s in zip(features_y, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                L_style += mse_loss(gm_y, gm_s[:n_batch, :, :])
            L_style *= args.lambda_style

            L_tv = (torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) +
                    torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))

            L_tv *= args.lambda_tv

            if args.noise_count:
                L_pop = args.lambda_noise * F.mse_loss(y, noisy_y)
                L = L_feat + L_style + L_tv + L_pop
                print(
                    'Epoch {},{}/{}. Total loss: {}. Loss distribution: feat {}, style {}, tv {}, pop {}'
                    .format(e, batch_id, len(data_train), L.data,
                            L_feat.data / L.data, L_style.data / L.data,
                            L_tv.data / L.data, L_pop.data / L.data))
            else:
                L = L_feat + L_style + L_tv
                print(
                    'Epoch {},{}/{}. Total loss: {}. Loss distribution: feat {}, style {}, tv {}'
                    .format(e, batch_id, len(data_train), L.data,
                            L_feat.data / L.data, L_style.data / L.data,
                            L_tv.data / L.data))
            # note: this reassignment overrides the total computed above, so only the
            # rescaled style and feature terms actually contribute to the gradients
            L = L_style * 1e10 + L_feat * 1e5
            L.backward()
            optimizer.step()

    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Example #11

def gram_matrix(x):
    (b, c, h, w) = x.size()
    ft = x.view(b, c, w * h)
    ft_t = torch.transpose(ft, 1, 2)
    gram = torch.bmm(ft, ft_t) / (c * h * w)
    return gram


dataset = ImageFolder(cfg.data, transform=tv.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True)

style, content = next(iter(dataloader))

vgg16 = Vgg16().to(device)
transform = TransNet().to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(transform.parameters(), lr=cfg.lr)

plt.ion()
for epoch in range(cfg.epochs):
    print('Epoch: {}/{}'.format(epoch + 1, cfg.epochs))

    for i, (s, x) in enumerate(dataloader):
        optimizer.zero_grad()

        s = s.to(device)
        x = x.to(device)
        y = transform(x)