Example #1
def train(**kwargs):
    opt = Config()
    for k, v in kwargs.items():
        setattr(opt, k, v)

    vis = Visualizer(env=opt.env)
    dataloader = get_dataloader(opt)
    _data = dataloader.dataset._data
    word2ix, ix2word = _data['word2ix'], _data['ix2word']

    # cnn = tv.models.resnet50(True)
    model = CaptionModel(opt, None, word2ix, ix2word)
    if opt.model_ckpt:
        model.load(opt.model_ckpt)

    optimizer = model.get_optimizer(opt.lr1)
    criterion = t.nn.CrossEntropyLoss()

    model.cuda()
    criterion.cuda()

    loss_meter = meter.AverageValueMeter()
    perplexity = meter.AverageValueMeter()

    for epoch in range(opt.epoch):

        loss_meter.reset()
        perplexity.reset()
        for ii, (imgs, (captions, lengths),
                 indexes) in tqdm.tqdm(enumerate(dataloader)):
            optimizer.zero_grad()
            imgs = imgs.cuda()
            captions = captions.cuda()

            imgs = Variable(imgs)
            captions = Variable(captions)
            input_captions = captions[:-1]
            target_captions = pack_padded_sequence(captions, lengths)[0]

            score, _ = model(imgs, input_captions, lengths)
            loss = criterion(score, target_captions)
            loss.backward()
            # clip_grad_norm(model.rnn.parameters(),opt.grad_clip)
            optimizer.step()
            loss_meter.add(loss.data[0])
            perplexity.add(t.exp(loss.data)[0])

            # Visualization
            if (ii + 1) % opt.plot_every == 0:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                vis.plot('loss', loss_meter.value()[0])
                vis.plot('perplexity', perplexity.value()[0])

                # Visualize the raw image

                raw_img = _data['train']['ix2id'][indexes[0]]
                img_path = '/data/image/ai_cha/caption/ai_challenger_caption_train_20170902/caption_train_images_20170902/' + raw_img
                raw_img = Image.open(img_path).convert('RGB')
                raw_img = tv.transforms.ToTensor()(raw_img)
                vis.img('raw', raw_img)

                # raw_img = (imgs.data[0]*0.25+0.45).clamp(max=1,min=0)
                # vis.img('raw',raw_img)

                # Visualize the human-written caption
                raw_caption = captions.data[:, 0]
                raw_caption = ''.join(
                    [_data['ix2word'][ii] for ii in raw_caption])
                vis.text(raw_caption, u'raw_caption')

                # Visualize the caption generated by the network
                results = model.generate(imgs.data[0])
                vis.text('</br>'.join(results), u'caption')
        if (epoch + 1) % 100 == 0:
            model.save()
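
Note: the loss in Example #1 (and the later caption examples) works because pack_padded_sequence flattens the padded, length-sorted captions into one long vector of valid tokens, which lines up row-for-row with the scores the model returns. A minimal, self-contained sketch of that alignment (the shapes and vocabulary size below are illustrative, not taken from the dataset):

import torch as t
from torch.nn.utils.rnn import pack_padded_sequence

vocab_size = 10
captions = t.randint(0, vocab_size, (5, 3))  # (max_len, batch) = (5, 3), time-major
lengths = [5, 4, 2]                          # per-caption lengths, sorted in decreasing order

# PackedSequence.data concatenates only the valid (non-padded) steps: 5 + 4 + 2 = 11 tokens
target_captions = pack_padded_sequence(captions, lengths)[0]  # shape: (11,)

# If the model emits one score row per valid step, the loss is a single call:
score = t.randn(target_captions.size(0), vocab_size)          # shape: (11, vocab_size)
loss = t.nn.CrossEntropyLoss()(score, target_captions)
print(target_captions.shape, loss.item())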
Example #2
def train(**kwargs):
    for k, v in kwargs.items():
        setattr(opt, k, v)

    vis = Visualizer(env=opt.env)

    # Load the data
    data, word2ix, ix2word = get_data(opt)
    data = t.from_numpy(data)
    dataloader = t.utils.data.DataLoader(data,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=1)

    # Define the model
    model = PoetryModel(len(word2ix), 128, 256)
    optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
    criterion = nn.CrossEntropyLoss()

    if opt.model_path:
        model.load_state_dict(t.load(opt.model_path))

    if opt.use_gpu:
        model.cuda()
        criterion.cuda()
    loss_meter = meter.AverageValueMeter()

    for epoch in range(opt.epoch):
        loss_meter.reset()
        for ii, data_ in tqdm.tqdm(enumerate(dataloader)):
            # Train
            data_ = data_.long().transpose(1, 0).contiguous()
            if opt.use_gpu: data_ = data_.cuda()
            optimizer.zero_grad()
            # Offset the input and target by one step
            input_, target = Variable(data_[:-1, :]), Variable(data_[1:, :])
            output, _ = model(input_)
            loss = criterion(output, target.view(-1))
            loss.backward()
            optimizer.step()

            loss_meter.add(loss.data[0])

            # Visualization
            if (1 + ii) % opt.plot_every == 0:

                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                vis.plot('loss', loss_meter.value()[0])
                # Original poems
                poetrys = [[ix2word[_word] for _word in data_[:, _iii]]
                           for _iii in range(data_.size(1))][:16]
                vis.text('</br>'.join([''.join(poetry) for poetry in poetrys]),
                         win=u'origin_poem')

                gen_poetries = []
                # Generate 8 poems, each starting with one of these characters
                for word in list(u'春江花月夜凉如水'):
                    gen_poetry = ''.join(
                        generate(model, word, ix2word, word2ix))
                    gen_poetries.append(gen_poetry)
                vis.text('</br>'.join(
                    [''.join(poetry) for poetry in gen_poetries]),
                         win=u'gen_poem')

        t.save(model.state_dict(), '%s_%s.pth' % (opt.model_prefix, epoch))
Example #3
def train(**kwargs):
    for k, v in kwargs.items():
        setattr(opt, k, v)

    vis = Visualizer(env=opt.env)

    # Load the data
    data, word2ix, ix2word = get_data(opt)
    data = t.from_numpy(data)
    dataloader = t.utils.data.DataLoader(data,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=1)

    # Define the model
    model = PoetryModel(len(word2ix), 128, 256)
    optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
    criterion = nn.CrossEntropyLoss()

    if opt.model_path:
        model.load_state_dict(t.load(opt.model_path))

    if opt.use_gpu:
        model.cuda()
        criterion.cuda()
    loss_meter = meter.AverageValueMeter()

    for epoch in range(opt.epoch):
        loss_meter.reset()
        for ii, data_ in tqdm.tqdm(enumerate(dataloader)):

            # Train
            data_ = data_.long().transpose(1, 0).contiguous()
            if opt.use_gpu: data_ = data_.cuda()
            optimizer.zero_grad()
            input_, target = Variable(data_[:-1, :]), Variable(data_[1:, :])
            output, _ = model(input_)
            loss = criterion(output, target.view(-1))
            loss.backward()
            optimizer.step()

            loss_meter.add(loss.data[0])

            # Visualization
            if (1 + ii) % opt.plot_every == 0:

                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                vis.plot('loss', loss_meter.value()[0])

                # Original poems
                poetrys = [[ix2word[_word] for _word in data_[:, _iii]]
                           for _iii in range(data_.size(1))][:16]
                vis.text('</br>'.join([''.join(poetry) for poetry in poetrys]), win=u'origin_poem')

                gen_poetries = []
                # Generate 8 poems, each starting with one of these characters
                for word in list(u'春江花月夜凉如水'):
                    gen_poetry = ''.join(generate(model, word, ix2word, word2ix))
                    gen_poetries.append(gen_poetry)
                vis.text('</br>'.join([''.join(poetry) for poetry in gen_poetries]), win=u'gen_poem')

        t.save(model.state_dict(), '%s_%s.pth' % (opt.model_prefix, epoch))
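
The poetry examples (#2 and #3) rely on the same one-step offset: for every position, the input is the current character and the target is the next one. A tiny self-contained illustration of that offset in the time-major layout used above (the five-character poem is just an example):

import torch as t

poem = list(u'床前明月光')
word2ix = {w: i for i, w in enumerate(poem)}
data_ = t.tensor([[word2ix[w]] for w in poem])  # (seq_len, batch) = (5, 1), time-major

input_, target = data_[:-1, :], data_[1:, :]    # input: 床前明月   target: 前明月光
print([poem[i] for i in input_.flatten().tolist()],
      [poem[i] for i in target.flatten().tolist()])

# The model's output has shape ((seq_len-1) * batch, vocab_size); flattening the target
# with target.view(-1) lines the two up for nn.CrossEntropyLoss, exactly as in the loop above.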
Example #4
def train(**kwargs):
    opt = Config()
    for k, v in kwargs.items():
        setattr(opt, k, v)
    device = t.device('cuda') if opt.use_gpu else t.device('cpu')

    opt.caption_data_path = 'caption.pth'  # raw data
    opt.test_img = ''  # input image
    # opt.model_ckpt='caption_0914_1947' # pretrained model

    # Data
    vis = Visualizer(env=opt.env)
    dataloader = get_dataloader(opt)
    _data = dataloader.dataset._data
    word2ix, ix2word = _data['word2ix'], _data['ix2word']

    # Model
    model = CaptionModel(opt, word2ix, ix2word)
    if opt.model_ckpt:
        model.load(opt.model_ckpt)
    optimizer = model.get_optimizer(opt.lr)
    criterion = t.nn.CrossEntropyLoss()
   
    model.to(device)

    # Statistics
    loss_meter = meter.AverageValueMeter()

    for epoch in range(opt.epoch):
        loss_meter.reset()
        for ii, (imgs, (captions, lengths), indexes) in tqdm.tqdm(enumerate(dataloader)):
            # Train
            optimizer.zero_grad()
            imgs = imgs.to(device)
            captions = captions.to(device)
            input_captions = captions[:-1]
            target_captions = pack_padded_sequence(captions, lengths)[0]
            score, _ = model(imgs, input_captions, lengths)
            loss = criterion(score, target_captions)
            loss.backward()
            optimizer.step()
            loss_meter.add(loss.item())

            # Visualization
            if (ii + 1) % opt.plot_every == 0:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                vis.plot('loss', loss_meter.value()[0])

                # Visualize the raw image and the human-written caption
                raw_img = _data['ix2id'][indexes[0]]
                img_path = opt.img_path + raw_img
                raw_img = Image.open(img_path).convert('RGB')
                raw_img = tv.transforms.ToTensor()(raw_img)

                raw_caption = captions.data[:, 0]
                raw_caption = ''.join([_data['ix2word'][ii] for ii in raw_caption])
                vis.text(raw_caption, u'raw_caption')
                vis.img('raw', raw_img, caption=raw_caption)

                # Visualize the caption generated by the network
                results = model.generate(imgs.data[0])
                vis.text('</br>'.join(results), u'caption')
        model.save()
Example #5
def train(**kwargs):
    for k, v in kwargs.items():
        setattr(opt, k, v)

    opt.device = t.device('cuda' if t.cuda.is_available() else 'cpu')
    device = opt.device
    vis = Visualizer(env=opt.env)

    # Load the data
    data_all = np.load(opt.pickle_path)
    data = data_all['data']
    word2ix = data_all['word2ix'].item()
    ix2word = data_all['ix2word'].item()
    data = t.from_numpy(data)
    dataloader = DataLoader(data,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=1)

    # Define the model
    model = PoetryModel(len(word2ix), 128, 256)
    optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
    loss_func = nn.CrossEntropyLoss()
    if opt.model_path:
        model.load_state_dict(
            t.load(opt.model_path, map_location=t.device('cpu')))
    model.to(device)

    loss_avg = 0
    for epoch in range(opt.epoch):
        for ii, data_ in tqdm(enumerate(dataloader)):
            data_ = data_.long()
            data_ = data_.to(device)
            optimizer.zero_grad()
            input_, target = data_[:, :-1], data_[:, 1:]
            output, _ = model(input_)
            loss = loss_func(output, target.reshape(-1))
            loss.backward()
            optimizer.step()

            loss_avg += loss.item()

            # Visualization
            if (ii + 1) % opt.plot_every == 0:
                vis.plot('loss', loss_avg / opt.plot_every)
                loss_avg = 0
                poetrys = [[ix2word[_word] for _word in data_[i].tolist()]
                           for i in range(data_.shape[0])][:16]
                vis.text('</br>'.join([''.join(poetry) for poetry in poetrys]),
                         win='origin_poem')

                gen_poetries = []
                for word in list('春江花月夜凉如水'):
                    gen_poetry = ''.join(
                        generate(model, word, ix2word, word2ix))
                    gen_poetries.append(gen_poetry)
                vis.text('</br>'.join(
                    [''.join(poetry) for poetry in gen_poetries]),
                         win='gen_poem')

        t.save(model.state_dict(), '%s_%s.pth' % (opt.model_prefix, epoch))
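
The poetry examples all call a generate(model, word, ix2word, word2ix) helper that is not shown in this section. As a rough sketch only, a greedy version compatible with those calls could look like the following; the '<EOP>' end marker, the optional hidden argument of model(), and the 125-character limit are assumptions, not taken from the code above:

import torch as t

def generate(model, start_word, ix2word, word2ix, max_len=125):
    # Hypothetical sketch: start from the forced first character, then greedily
    # pick the most likely next character until '<EOP>' (or max_len) is reached.
    results = [start_word]
    index = word2ix[start_word]
    hidden = None
    for _ in range(max_len):
        input_ = t.tensor([[index]]).long()      # (seq_len=1, batch=1), time-major
        output, hidden = model(input_, hidden)   # assumes model(input, hidden=None)
        index = output.view(-1, len(word2ix))[-1].argmax().item()
        word = ix2word[index]
        if word == '<EOP>':
            break
        results.append(word)
    return results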
Example #6
def train(**kwargs):

    for k, v in kwargs.items():
        setattr(opt, k, v)

    vis = Visualizer(env=opt.env)

    # Load the data
    data, word2ix, ix2word = get_data(opt)
    data = t.from_numpy(data)  # convert the numpy array to a tensor
    dataloader = t.utils.data.DataLoader(data,  # build the DataLoader instance
                    batch_size=opt.batch_size,
                    shuffle=True,
                    num_workers=1)

    # Define the model
    model = PoetryModel(len(word2ix), 128, 256)  # (vocab_size, embedding_dim, hidden_dim)
    optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
    criterion = nn.CrossEntropyLoss()  # cross-entropy loss
    
    if opt.model_path:
        model.load_state_dict(t.load(opt.model_path))

    if opt.use_gpu:
        model.cuda()
        criterion.cuda()
    loss_meter = meter.AverageValueMeter()

    for epoch in range(opt.epoch):
        loss_meter.reset()
        for ii, data_ in tqdm.tqdm(enumerate(dataloader)):  # tqdm progress bar
            # Take one batch of data and train on it

            # data_.size(): (batch_size, maxlen)
            # transpose, then return a contiguous tensor with the same data
            data_ = data_.long().transpose(1, 0).contiguous()
            # if epoch==0 and ii ==0:
            #     print('size of data_ after transpose: \n',data_.size())
            if opt.use_gpu: data_ = data_.cuda()
            optimizer.zero_grad()  # clear the gradients

            # input_ holds the first maxlen-1 tokens of every poem and target holds
            # the last maxlen-1 tokens. With "床前明月光" as an example, the input is
            # "床前明月" and the model should predict "前明月光".
            input_, target = Variable(data_[:-1, :]), Variable(data_[1:, :])
            output, _ = model(input_)

            # Tensor.view(-1) flattens the target into a 1-D array, element by element
            loss = criterion(output, target.view(-1))
            loss.backward()
            optimizer.step()
        
            loss_meter.add(loss.data[0])

            # Visualization
            if (1 + ii) % opt.plot_every == 0:

                if os.path.exists(opt.debug_file):
                    # if the debug file exists, drop into the debugger
                    ipdb.set_trace()

                vis.plot('loss', loss_meter.value()[0])

                # Original poems: convert every token id of each poem back to text,
                # then show the first 16 poems in visdom (_iii indexes the batch dimension)
                poetrys = [[ix2word[_word] for _word in data_[:, _iii]]
                           for _iii in range(data_.size(1))][:16]
                vis.text('</br>'.join([''.join(poetry) for poetry in poetrys]), win=u'origin_poem')
                gen_poetries = []
                # Generate 8 poems, each starting with one of these characters
                for word in list(u'春江花月夜凉如水'):
                    gen_poetry = ''.join(generate(model, word, ix2word, word2ix))
                    gen_poetries.append(gen_poetry)
                vis.text('</br>'.join([''.join(poetry) for poetry in gen_poetries]), win=u'gen_poem')
        
        t.save(model.state_dict(), '%s_%s.pth' % (opt.model_prefix, epoch))
Example #7
def train(**kwargs):
    for k, v in kwargs.items():
        setattr(opt, k, v)
    
    vis = Visualizer(env=opt.env)
    
    # Load the data
    data, word2ix, ix2word = get_data(opt)
    data = t.from_numpy(data)
    dataloader = t.utils.data.DataLoader(data, batch_size=opt.batch_size, shuffle=True, num_workers=2)
    
    # Define the model
    model = PoetryModel(len(word2ix), opt.embedding_dim, opt.hidden_dim)
    # Optimizer
    optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
    # Loss Function
    criterion = nn.CrossEntropyLoss()
    
    
    
    # Load a pretrained model so that training can be resumed
    if opt.model_path and os.path.exists(opt.model_path):
        model.load_state_dict(t.load(opt.model_path))
    
    # GPU related

    if opt.use_gpu:
        model = model.to(device)
        criterion = criterion.to(device)
    
    # loss meter
    loss_meter = meter.AverageValueMeter()
    
    # for loop
    for epoch in range(opt.epoch):
        loss_meter.reset()
        
        # for : batching dataset
        for i, data_ in tqdm.tqdm(enumerate(dataloader)):
            
            # Train
            # data_
            # size: [128, 125]  - 128 poems per batch, each 125 tokens long
            # type: Tensor
            # dtype: torch.int32, which should be converted to long

            # This single line does a lot:
            # step 1: cast int32 to long
            # step 2: swap rows and columns, which the parallel computation needs
            # step 3: lay the data out contiguously in memory to avoid errors in later ops
            data_ = data_.long().transpose(0, 1).contiguous()
            
            # GPU related
            if opt.use_gpu:
                data_ = data_.to(device)
            
            # by this point data_.dtype has become torch.int64
            # print(data_.dtype)

            # clear the gradients
            optimizer.zero_grad()

            # Offset training, which is easy to follow:
            # the first n-1 rows become the input and the last n-1 rows become the target,
            # again so that the whole batch can be computed in parallel.
            # input_ has a trailing underscore to avoid shadowing the built-in input.
            input_, target = data_[:-1, :], data_[1:, :]
            
            # the model returns output and hidden; hidden is not used here
            output, _ = model(input_)

            # compute the loss
            target = target.view(-1)

            # flattened target.size() is [15872]  (124 * 128 = 15872)
            # output.size() is [15872, 8293], where 8293 is the vocabulary size

            loss = criterion(output, target)

            # backpropagation
            loss.backward()

            # optimizer updates the parameters by gradient descent
            optimizer.step()

            loss_meter.add(loss.item())

            # Visualization
            if (1 + i) % opt.plot_every == 0:

                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                vis.plot('loss', loss_meter.value()[0])

                # Original poems
                poetrys = [[ix2word[_word.item()] for _word in data_[:, _iii]]
                           for _iii in range(data_.size(1))][:16]
                vis.text('</br>'.join([''.join(poetry) for poetry in poetrys]), win=u'origin_poem')

                gen_poetries = []
                # Generate 8 poems, each starting with one of these characters
                for word in list(u'春江花月夜凉如水'):
                    gen_poetry = ''.join(generate(model, word, ix2word, word2ix))
                    gen_poetries.append(gen_poetry)
                vis.text('</br>'.join([''.join(poetry) for poetry in gen_poetries]), win=u'gen_poem')
        # save the model after every epoch
        t.save(model.state_dict(), '%s_%s.pth' % (opt.model_prefix, epoch))
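
PoetryModel itself is only ever instantiated as PoetryModel(vocab_size, embedding_dim, hidden_dim) in this section. Judging from the shape comments in Example #7 (output of size [seq_len * batch, vocab_size]), a plausible minimal implementation is an embedding, an LSTM and a linear projection; the sketch below is an assumption about that layout, not the verified original:

from torch import nn

class PoetryModel(nn.Module):
    # Hypothetical sketch matching the calling convention used above:
    # input of shape (seq_len, batch), output of shape (seq_len * batch, vocab_size).
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=2)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_, hidden=None):
        seq_len, batch_size = input_.size()
        embeds = self.embeddings(input_)             # (seq_len, batch, embedding_dim)
        output, hidden = self.lstm(embeds, hidden)   # (seq_len, batch, hidden_dim)
        output = self.linear(output.reshape(seq_len * batch_size, -1))  # (seq_len*batch, vocab_size)
        return output, hidden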
Example #8
def train():
    model = IMAGE_AI_MODEL()
    model.train()
    model.cuda()
    criterion = t.nn.NLLLoss()
    optimizer = t.optim.Adam(model.parameters(), lr=1e-3)
    dataloader = get_dataloader()
    data_set = dataloader.dataset
    print(len(data_set))
    ix2word = dataloader.dataset.ix2word
    _data = dataloader.dataset._data
    vis = Visualizer(env='word_embedding_caption')
    loss_meter = meter.AverageValueMeter()
    for epoch in range(10):
        loss_meter.reset()
        for ii, (img_lows, img_highs, cap_tensor, lengths,
                 indexs) in tqdm.tqdm(enumerate(dataloader)):
            optimizer.zero_grad()
            loss = 0
            batch_target_length = 0
            for i in range(8):
                decoder_hidden = img_lows[[i]].unsqueeze(0)
                cell_hidden = decoder_hidden.clone()
                encoder_outputs = img_highs[i]
                target_tensor = cap_tensor[i]
                target_length = lengths[i]
                batch_target_length += target_length
                decoder_input = t.tensor([0])
                decoder_hidden = decoder_hidden.cuda()
                cell_hidden = cell_hidden.cuda()
                encoder_outputs = encoder_outputs.cuda()
                target_tensor = target_tensor.cuda()
                decoder_input = decoder_input.cuda()
                raw_img = _data['ix2id'][indexs[i]]
                img_path_q = 'ai_challenger_caption_train_20170902/caption_train_images_20170902/'
                img_path = img_path_q + raw_img
                true_words = []
                for w in range(target_length):
                    true_words.append(ix2word[target_tensor[w].item()])
                    true_words.append('|')
                decoded_words = []
                for di in range(target_length):
                    decoder_output, decoder_hidden, cell_hidden, decoder_attention = model(
                        decoder_input, decoder_hidden, cell_hidden,
                        encoder_outputs)
                    loss += criterion(decoder_output, target_tensor[[di]])
                    decoder_input = target_tensor[[di]]
                    topv, topi = decoder_output.data.topk(1)
                    if topi.item() == 2:
                        decoded_words.append('<EOS>')
                        break
                    else:
                        decoded_words.append(ix2word[topi.item()])
            loss.backward()
            loss_batch = loss.item() / batch_target_length
            loss_meter.add(loss_batch)
            optimizer.step()
            plot_every = 10
            if (ii + 1) % plot_every == 0:
                vis.plot('loss', loss_meter.value()[0])
                raw_img = Image.open(img_path).convert('RGB')
                raw_img = tv.transforms.ToTensor()(raw_img)
                vis.img('raw', raw_img)
                raw_caption = ''.join(decoded_words)
                vis.text(raw_caption, win='raw_caption')
                true_caption = ''.join(true_words)
                vis.text(true_caption, win='true_caption')
        # save
        prefix = 'IMAGE_AI_MODEL'
        path = '{prefix}_{time}'.format(prefix=prefix,
                                        time=time.strftime('%m%d_%H%M'))
        t.save(model.state_dict(), path)
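
Example #8 decodes with teacher forcing: at every step the decoder is fed the ground-truth token from the previous step (decoder_input = target_tensor[[di]]) rather than its own prediction. A toy illustration of that pattern with a stand-in GRU decoder (none of these layers are the IMAGE_AI_MODEL above):

import torch as t
from torch import nn

vocab_size, hidden_size = 20, 16
embed = nn.Embedding(vocab_size, hidden_size)
rnn_cell = nn.GRUCell(hidden_size, hidden_size)
out = nn.Linear(hidden_size, vocab_size)
criterion = nn.NLLLoss()

target = t.randint(0, vocab_size, (6,))   # one ground-truth caption
hidden = t.zeros(1, hidden_size)
decoder_input = t.tensor([0])             # start token, as in Example #8
loss = 0
for di in range(target.size(0)):
    hidden = rnn_cell(embed(decoder_input), hidden)
    log_probs = nn.functional.log_softmax(out(hidden), dim=-1)
    loss = loss + criterion(log_probs, target[[di]])
    decoder_input = target[[di]]          # teacher forcing: feed the true token back in
    # free running would instead use: decoder_input = log_probs.argmax(dim=-1)
print(loss.item() / target.size(0))       # mirrors loss_batch = loss.item() / batch_target_length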
Example #9
def train(**kwargs):
    opt = Config()
    for k, v in kwargs.items():
        setattr(opt, k, v)
    device = t.device('cuda') if opt.use_gpu else t.device('cpu')

    opt.caption_data_path = 'caption.pth'  # raw data
    opt.test_img = ''  # input image
    # opt.model_ckpt='caption_0914_1947' # pretrained model

    # Data
    vis = Visualizer(env=opt.env)
    dataloader = get_dataloader(opt)
    _data = dataloader.dataset._data
    word2ix, ix2word = _data['word2ix'], _data['ix2word']

    # Model
    model = CaptionModel(opt, word2ix, ix2word)
    if opt.model_ckpt:
        model.load(opt.model_ckpt)
    optimizer = model.get_optimizer(opt.lr)
    criterion = t.nn.CrossEntropyLoss()

    model.to(device)

    # Statistics
    loss_meter = meter.AverageValueMeter()

    for epoch in range(opt.epoch):
        loss_meter.reset()
        for ii, (imgs, (captions, lengths),
                 indexes) in tqdm.tqdm(enumerate(dataloader)):
            # Train
            optimizer.zero_grad()
            imgs = imgs.to(device)
            captions = captions.to(device)
            input_captions = captions[:-1]
            target_captions = pack_padded_sequence(captions, lengths)[0]
            score, _ = model(imgs, input_captions, lengths)
            loss = criterion(score, target_captions)
            loss.backward()
            optimizer.step()
            loss_meter.add(loss.item())

            # Visualization
            if (ii + 1) % opt.plot_every == 0:
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                vis.plot('loss', loss_meter.value()[0])

                # Visualize the raw image and the human-written caption
                raw_img = _data['ix2id'][indexes[0]]
                img_path = opt.img_path + raw_img
                raw_img = Image.open(img_path).convert('RGB')
                raw_img = tv.transforms.ToTensor()(raw_img)

                raw_caption = captions.data[:, 0]
                raw_caption = ''.join(
                    [_data['ix2word'][ii] for ii in raw_caption])
                vis.text(raw_caption, u'raw_caption')
                vis.img('raw', raw_img, caption=raw_caption)

                # Visualize the caption generated by the network
                results = model.generate(imgs.data[0])
                vis.text('</br>'.join(results), u'caption')
        model.save()
Example #10
File: trainer.py  Project: Seolen/Vnet_Seg
class Trainer(object):
    def __init__(self, **kwargs):
        ''' external kwargs -> init params -> data prepare -> model load -> learning def '''

        opt.parse(kwargs)
        self.env = opt.env
        self.vis = Visualizer(opt.env)
        self.vis.text(opt.notes)

        self.evaluator = Evaluator(opt.num_classes)
        self.best_acc = self.best_epoch = -1

        self.train_loader, self.val_loader = self.data_process()
        self.model = self.set_model()
        self.criterion, self.optimizer, self.scheduler = self.learning(
            self.model)

    def forward(self):
        ''' train and val '''

        vis, evaluator = self.vis, self.evaluator
        train_loader, val_loader, model = self.train_loader, self.val_loader, self.model
        criterion, optimizer, scheduler = self.criterion, self.optimizer, self.scheduler

        for epoch_i in tqdm(range(opt.epoch)):
            # adjust learning rate
            if opt.use_stepLR:
                scheduler.step()
            else:  # poly_lr_scheduler
                power = 0.9
                new_lr = opt.lr * (1 - epoch_i / opt.epoch)**power
                for param_group in optimizer.param_groups:
                    param_group['lr'] = new_lr

            train_loss, train_mean_ious, train_ious, train_dice, train_hausdorff = self.train(
                model, train_loader, criterion, optimizer, evaluator)
            test_loss, test_mean_ious, test_ious, test_dice, test_hausdorff = self.val(
                model, val_loader, criterion, evaluator)
            vis.plot('loss', [train_loss, test_loss])
            vis.plot('mIOU', [train_mean_ious, test_mean_ious])
            vis.plot('train_IoU', train_ious[1:])
            vis.plot('test_IoU', test_ious[1:])
            vis.plot('train_dice', train_dice[1:])
            vis.plot('test_dice', test_dice[1:])
            vis.plot('hausdorff_distance', [train_hausdorff, test_hausdorff])

            if self.acc_update(test_mean_ious):
                self.model_save(model,
                                epoch_i,
                                test_mean_ious,
                                name=opt.env + '_best.pth')
        vis.text('Best accuracy %f' % self.best_acc)
        print('Best accuracy', self.best_acc)

    def train(self, model, dataloader, criterion, optimizer, evaluator):
        model.train()
        evaluator.reset()
        loss_meter = meter.AverageValueMeter()

        for batch_idx, sample in tqdm(enumerate(dataloader)):
            img, target = sample['image'], sample['label']
            # print(img.data.shape, target.data.shape)

            img, target = Variable(img.cuda()), Variable(target.cuda())

            optimizer.zero_grad()
            output = model(img)
            loss = criterion(output, target)

            loss.backward()
            optimizer.step()

            # metrics, prediction
            loss_meter.add(loss.data.item())
            pred = output.data.cpu().argmax(1)  # shape(1, 64, 64, 32)
            evaluator.add_batch(target.data.cpu().numpy(), pred.numpy())
            '''
            # For Classification Prediction
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
            '''

        return loss_meter.value()[0], evaluator.Mean_Intersection_over_Union(ignore_index=0), evaluator.Intersection_over_Union(), \
               evaluator.Dice_coefficient(), evaluator.Hausdorff()

    def val(self, model, dataloader, criterion, evaluator):

        model.eval()
        evaluator.reset()
        loss_meter = meter.AverageValueMeter()

        for batch_idx, sample in tqdm(enumerate(dataloader)):
            img, target = sample['image'], sample['label']
            # print(img.data.shape)
            img, target = Variable(img.cuda()), Variable(target.cuda())

            with torch.no_grad():
                output = model(img)
            loss = criterion(output, target)

            # metrics, prediction
            loss_meter.add(loss.data.item())
            pred = output.data.cpu().argmax(1)
            evaluator.add_batch(target.data.cpu().numpy(), pred.numpy())
            '''
            # For Classification Prediction
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
            '''

        return loss_meter.value()[0], evaluator.Mean_Intersection_over_Union(ignore_index=0), evaluator.Intersection_over_Union(), \
               evaluator.Dice_coefficient(), evaluator.Hausdorff()

    def data_process(self):
        dataset = getattr(data, opt.dataset)
        trainset = dataset(opt.datadir,
                           split='train',
                           use_truncated=opt.use_truncated)
        valset = dataset(opt.datadir,
                         split='test',
                         use_truncated=opt.use_truncated)
        train_loader = DataLoader(trainset,
                                  batch_size=opt.batch_size,
                                  shuffle=False,
                                  num_workers=opt.num_workers)
        val_loader = DataLoader(valset,
                                batch_size=opt.val_batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers)
        print('Data processed.')

        return train_loader, val_loader

    def set_model(self):
        model = VNet(num_classes=opt.num_classes)
        if opt.use_init:
            model.weights_init()
        model.cuda()

        if opt.use_parallel:
            model = nn.DataParallel(model)
        if opt.use_pretrained:
            pt = torch.load(opt.pretrain + opt.pretrained_name)['state_dict']
            model.load_state_dict(pt)

        return model

    def learning(self, model):
        if opt.use_balance_weight:
            weights = torch.FloatTensor(opt.balance_weight).cuda()
        else:
            weights = None
        if not opt.use_dice_loss:
            criterion = nn.CrossEntropyLoss(weight=weights)
        else:
            criterion = DiceLoss()
        if not opt.use_perparam:
            optimizer = optim.SGD(model.parameters(),
                                  lr=opt.lr,
                                  momentum=opt.momentum,
                                  weight_decay=opt.weight_decay)
        else:
            train_params = [{
                'params': model.get_1x_lr_params(),
                'lr': opt.lr
            }, {
                'params': model.get_10x_lr_params(),
                'lr': opt.lr * 10
            }]
            optimizer = optim.SGD(train_params,
                                  momentum=opt.momentum,
                                  weight_decay=opt.weight_decay)
        # optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=20,
                                              gamma=0.1)

        return criterion, optimizer, scheduler

    def acc_update(self, cur_acc):
        if cur_acc > self.best_acc:
            self.best_acc = cur_acc
            return True
        return False

    def model_init(self, model):
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                # m.bias.data.zero_()

    def model_save(self, model, epoch, metric, name=opt.env + '_best.pth'):
        prefix = 'results/checkpoints/'
        torch.save(
            {
                'epoch': epoch,
                'state_dict': model.module.state_dict(),
                'metric': metric
            }, prefix + name)
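
When opt.use_stepLR is off, Trainer.forward above applies a 'poly' learning-rate schedule, lr * (1 - epoch / max_epoch) ** 0.9, directly to the optimizer's param_groups. A quick standalone check of those values (the base rate and epoch count here are illustrative, not the project's config):

base_lr, max_epoch, power = 0.01, 100, 0.9

for epoch_i in (0, 25, 50, 75, 99):
    new_lr = base_lr * (1 - epoch_i / max_epoch) ** power
    print(epoch_i, round(new_lr, 5))
# epoch 0 -> 0.01, epoch 50 -> ~0.00536, epoch 99 -> ~0.00016:
# the rate decays smoothly toward zero instead of dropping in steps.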