示例#1
0
 def do_train(self):
     """Train the network for up to ``self.max_run`` epochs.

     Every batch is run through ``self.train_op``; the dice loss evaluated
     right after the update is accumulated in ``self.dice_loss_meter`` and
     sample 0 of the batch (input, mask, prediction) is pushed to ``vis`` as
     an 8x8 mosaic of its first 64 slices.

     At the end of each epoch the learning rate is multiplied by 0.95 when
     the meter's value is larger than the previous epoch's.  A checkpoint is
     written every 10th epoch; training stops, after one last checkpoint,
     once the learning rate drops below 1e-7.
     """

     def to_mosaic(batch):
         """Tile the first 64 slices of sample 0, channel 0, as a 512x512 image."""
         # (64, 64, 64) -> (row, col, y, x) -> (row, y, col, x) -> (512, 512):
         # slice ``l`` lands in grid row ``l // 8`` and grid column ``l % 8``.
         grid = batch[0, :64, :, :, 0].reshape(8, 8, 64, 64)
         return grid.transpose(0, 2, 1, 3).reshape(512, 512).astype(np.float32)

     pre_loss = 1
     dataset = LungDataset(self.img_dir, augument=True)
     lr = 0.01
     for i in range(self.max_run):
         self.dice_loss_meter.reset()
         start_time = time.time()
         train_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True,
                                   num_workers=1, pin_memory=True, drop_last=True)
         for batch_idx, (img_, mask_, _name) in enumerate(train_loader):
             self.global_steps += 1
             # Reorder axes to (0, 2, 3, 4, 1) before feeding the placeholders.
             img = img_.numpy().transpose([0, 2, 3, 4, 1])
             mask = mask_.numpy().transpose([0, 2, 3, 4, 1])
             feed_dict = {
                 self.image_batch: img,
                 self.label_batch: mask,
                 self.learning_rate: lr,
             }
             _, prob_ = self.sess.run([self.train_op, self.prob], feed_dict=feed_dict)
             # Evaluated in a separate run, i.e. after the weight update above.
             dice = self.loss_dice.eval(feed_dict, session=self.sess)
             self.dice_loss_meter.add(dice)
             if batch_idx > 10:
                 running_loss = self.dice_loss_meter.value()[0]
                 vis.plot('dice_loss', running_loss)
                 vis.plot("dice", 1 - running_loss)

             # Flag batches where the prediction is (almost) empty.
             prob_total = np.sum(prob_)
             if prob_total < 5:
                 vis.plot('pred__', prob_total)
             vis.img('input', torch.from_numpy(to_mosaic(img)))
             vis.img('mask', torch.from_numpy(to_mosaic(mask)))
             vis.img('pred', torch.from_numpy(to_mosaic(prob_)))

             if self.global_steps % 50 == 0:
                 logitss = self.logits.eval(feed_dict, session=self.sess)
                 print("logits  %.4f" % np.sum(logitss))
                 losss = self.cross_loss.eval(feed_dict, session=self.sess)
                 dice = self.loss_dice.eval(feed_dict, session=self.sess)
                 print("Epoch: [%2d]  [%4d] ,time: %4.4f,dice_loss:%.8f,dice:%.8f,cross_loss:%.8f" % \
                       (i, batch_idx, time.time() - start_time, dice, 1 - dice, losss))

         # ---- end of epoch: learning-rate decay, early stop, checkpoint ----
         epoch_loss = self.dice_loss_meter.value()[0]
         if epoch_loss > pre_loss:
             lr = lr * 0.95
             # Single pre-formatted argument: prints the same text under
             # Python 2 and Python 3.
             print("pre_loss:  %s  now_loss:  %s  lr:  %s" % (pre_loss, epoch_loss, lr))
         pre_loss = epoch_loss
         # NOTE(review): the ``save`` argument order is kept exactly as found;
         # confirm it matches the helper's signature.
         if lr < 1e-7:
             save(self.saver, self.sess, self.restore_from, self.global_steps,
                  self.model_dir, train_tag="mask_predict")
             print("stop for lr<1e-7")
             break
         if i % 10 == 0:
             save(self.saver, self.sess, self.restore_from, self.global_steps,
                  self.model_dir, train_tag="mask_predict")
示例#2
0
def _tile_slices(batch):
    """Return a 512x512 float32 mosaic of sample 0's first 64 slices.

    ``batch`` is indexed as ``batch[0, l, :, :, 0]`` with 64x64 slices; slice
    ``l`` is placed in grid row ``l // 8`` and grid column ``l % 8``.
    """
    # (64, 64, 64) -> (row, col, y, x) -> (row, y, col, x) -> (512, 512)
    grid = batch[0, :64, :, :, 0].reshape(8, 8, 64, 64)
    return grid.transpose(0, 2, 1, 3).reshape(512, 512).astype(np.float32)


def main():
    """Build the 3D U-Net graph and train it on the dice loss.

    The graph is optimised with a momentum optimiser whose gradients are
    clipped to a global norm of 1; the learning rate is fed per step through
    a placeholder.  Per step, the dice loss is tracked in an average meter
    and an 8x8 slice mosaic of the first sample is sent to ``vis``.  Per
    epoch, the learning rate is multiplied by 0.95 when the meter's value got
    worse, a checkpoint is written every 10th epoch, and training stops after
    a final checkpoint once the learning rate falls below 1e-7.

    Reads the module-level names ``batch_size``, ``luna_data``,
    ``restore_from``, ``max_run`` and ``models_path`` and increments the
    global step counter ``iters``.
    """
    global iters
    pre_loss = 1
    dice_loss_meter = tnt.meter.AverageValueMeter()
    vis.vis.texts = ''
    image_batch = tf.placeholder(tf.float32, shape=[None, 64, 64, 64, 1])
    label_batch = tf.placeholder(tf.float32, shape=[None, 64, 64, 64, 1])
    net = Unet3D({'data': image_batch}, batch_size=batch_size, keep_prob=0.5)
    dataset = LungDataset(luna_data + 'train/', augument=True)
    prob = net.layers['result']
    logits = net.layers['conv_8']
    # Histogram summaries are registered in the graph but never fetched by
    # the loop below (no summary writer is active in this script).
    tf.summary.histogram("logits", logits)
    tf.summary.histogram("conv7_1", net.layers['conv7_1'])
    tf.summary.histogram("conv7_2", net.layers['conv7_2'])
    print("logits--------------: %s" % (logits.shape,))
    cross_loss = pixelwise_cross_entropy(logits, label_batch)
    all_trainable = tf.trainable_variables()
    # Snapshot taken before the optimiser is built, so variables created
    # after this point are not part of the saved/restored set.
    restore_var = tf.global_variables()

    loss_dice = dice_coef_loss(prob, label_batch)

    # Result is unused; the call is kept because side effects of the helper
    # cannot be ruled out from this file.
    extraLoss(prob, label_batch)
    Loss = loss_dice
    learning_rate = tf.placeholder(tf.float32)
    optimiser = tf.train.MomentumOptimizer(learning_rate, 0.99)
    gradients = tf.gradients(Loss, all_trainable)
    clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.)
    train_op = optimiser.apply_gradients(zip(clipped_gradients, all_trainable))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=restore_var, max_to_keep=40)
    lr = 0.01
    # Load variables if the checkpoint is provided.
    if restore_from is not None:
        loader = tf.train.Saver(var_list=restore_var)
        load(loader, sess, restore_from, "unet3d_v1")

    for i in range(max_run):
        dice_loss_meter.reset()
        start_time = time.time()
        train_loader = DataLoader(dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=1,
                                  pin_memory=True,
                                  drop_last=True)
        for batch_idx, (img_, mask_, _name) in enumerate(train_loader):
            iters += 1
            # Reorder axes to (0, 2, 3, 4, 1) to match the placeholders.
            img = img_.numpy().transpose([0, 2, 3, 4, 1])
            mask = mask_.numpy().transpose([0, 2, 3, 4, 1])
            feed_dict = {
                image_batch: img,
                label_batch: mask,
                learning_rate: lr
            }
            _, prob_ = sess.run([train_op, prob], feed_dict=feed_dict)
            # Evaluated in a separate run, i.e. after the weight update above.
            dice = loss_dice.eval(feed_dict, session=sess)
            dice_loss_meter.add(dice)
            if batch_idx > 10:
                running_loss = dice_loss_meter.value()[0]
                vis.plot('dice_loss', running_loss)
                vis.plot("dice", 1 - running_loss)

            # Flag batches where the prediction is (almost) empty.
            prob_total = np.sum(prob_)
            if prob_total < 5:
                vis.plot('pred__', prob_total)
            vis.img('input', torch.from_numpy(_tile_slices(img)))
            vis.img('mask', torch.from_numpy(_tile_slices(mask)))
            vis.img('pred', torch.from_numpy(_tile_slices(prob_)))

            if iters % 50 == 0:
                logitss = logits.eval(feed_dict, session=sess)
                print("logits  %.4f" % np.sum(logitss))
                losss = cross_loss.eval(feed_dict, session=sess)
                dice = loss_dice.eval(feed_dict, session=sess)
                all_loss = Loss.eval(feed_dict, session=sess)
                print("Epoch: [%2d]  [%4d] ,time: %4.4f,all_loss:%.8f,dice_loss:%.8f,dice:%.8f,cross_loss:%.8f" % \
                      (i,batch_idx,time.time() - start_time,all_loss,dice,1-dice,losss))

        # ---- end of epoch: learning-rate decay, early stop, checkpoint ----
        epoch_loss = dice_loss_meter.value()[0]
        if epoch_loss > pre_loss:
            lr = lr * 0.95
            # Single pre-formatted argument: prints the same text under
            # Python 2 and Python 3.
            print("pre_loss:  %s  now_loss:  %s  lr:  %s" % (pre_loss, epoch_loss, lr))
        pre_loss = epoch_loss
        if lr < 1e-7:
            save(saver,
                 sess,
                 models_path,
                 iters,
                 "unet3d_v1",
                 train_tag="mask_predict")
            print("stop for lr<1e-7")
            break
        if i % 10 == 0:
            save(saver,
                 sess,
                 models_path,
                 iters,
                 "unet3d_v1",
                 train_tag="mask_predict")
示例#3
0
def _svalue_range(epoch):
    """Return the ``(low, high)`` bounds used when redrawing ``args.svalue``.

    Epochs below 50 use (0, 0.1), epochs 50-79 use (0.2, 0.5) and all later
    epochs use (0.6, 0.9).
    """
    if epoch < 50:
        return 0, 0.1
    if epoch < 80:
        return 0.2, 0.5
    return 0.6, 0.9


def train():
    """Build the encoder-decoder / discriminator graph and train it.

    One image pair is fed per step.  The generator and the discriminator are
    updated together by a grouped Adam step whose learning rate is constant
    (``args.base_lr``) until ``args.epoch_step`` and then decays linearly
    towards zero at ``args.epoch``.  Checkpoints, scalar summaries and sample
    images are written every ``args.save_pred_every``,
    ``args.summary_pred_every`` and ``args.write_pred_every`` steps.

    While fewer than 50000 "human" updates have happened (checked once at the
    start of each epoch), every ``args.human_number``-th step redraws
    ``args.svalue`` from an epoch-dependent range and sets
    ``args.coefficient`` to 0.3.
    """
    args = parse_args()
    for directory in (args.snapshot_dir, args.out_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    x_datalists, y_datalists = make_train_data_list(args.x_train_data_path, args.y_train_data_path)
    tf.set_random_seed(args.random_seed)
    image_shape = [1, args.image_size, args.image_size, 3]
    x_img = tf.placeholder(tf.float32, shape=image_shape, name='x_img')
    y_img = tf.placeholder(tf.float32, shape=image_shape, name='y_img')

    fake_y = Encoder_Decoder(image=x_img, reuse=False, name='Encoder_Decoder')
    dy_fake = discriminator(image=fake_y, reuse=False, name='discriminator_fake')
    dx_real = discriminator(image=x_img, reuse=False, name='discriminator_real1')
    dy_real = discriminator(image=y_img, reuse=False, name='discriminator_real2')

    encoder_loss = gan_loss(dy_fake, tf.ones_like(dy_fake)) + args.lamda * l1_loss(y_img, fake_y)
    # Each "real" term is compared against a target built from its own logits.
    dis_loss = (gan_loss(dy_fake, tf.zeros_like(dy_fake))
                + gan_loss(dx_real, tf.ones_like(dx_real))
                + gan_loss(dy_real, tf.ones_like(dy_real)))

    # NOTE(review): ``args.coefficient`` and ``args.svalue`` are read here,
    # while the graph is being built.  If they are plain Python numbers, the
    # reassignments done later in the training loop do not change this tensor
    # -- confirm whether that is intended.
    dis_loss = args.coefficient * dis_loss + (1 - args.coefficient) * args.svalue

    encoder_loss_sum = tf.summary.scalar("encoder_loss", encoder_loss)
    discriminator_sum = tf.summary.scalar("dis_loss", dis_loss)

    summary_writer = tf.summary.FileWriter(args.snapshot_dir, graph=tf.get_default_graph())

    g_vars = [v for v in tf.trainable_variables() if 'Encoder_Decoder' in v.name]
    d_vars = [v for v in tf.trainable_variables() if 'discriminator' in v.name]

    lr = tf.placeholder(tf.float32, None, name='learning_rate')
    d_optim = tf.train.AdamOptimizer(lr, beta1=args.beta1)
    e_optim = tf.train.AdamOptimizer(lr, beta1=args.beta1)
    d_grads_and_vars = d_optim.compute_gradients(dis_loss, var_list=d_vars)
    d_train = d_optim.apply_gradients(d_grads_and_vars)
    e_grads_and_vars = e_optim.compute_gradients(encoder_loss, var_list=g_vars)
    e_train = e_optim.apply_gradients(e_grads_and_vars)

    train_op = tf.group(d_train, e_train)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    try:
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=100)
            counter = 0  # global step count across epochs
            number = 0   # how many times svalue has been redrawn
            ckpt = tf.train.get_checkpoint_state(args.snapshot_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print("Model restored...")
            else:
                print('No Model')

            for epoch in range(args.epoch):
                # Decided once per epoch; ``None`` disables the redraw.
                svalue_range = _svalue_range(epoch) if number < 50000 else None
                lrate = args.base_lr if epoch < args.epoch_step else args.base_lr*(args.epoch-epoch)/(args.epoch-args.epoch_step)
                for step in range(len(x_datalists)):
                    counter += 1
                    x_image_resize, y_image_resize = TrainImageReader(
                        x_datalists, y_datalists, step, args.image_size)
                    # Add the leading batch axis of size 1.
                    batch_x_image = np.expand_dims(np.array(x_image_resize).astype(np.float32), axis=0)
                    batch_y_image = np.expand_dims(np.array(y_image_resize).astype(np.float32), axis=0)
                    feed_dict = {lr: lrate, x_img: batch_x_image, y_img: batch_y_image}
                    encoder_loss_value, dis_loss_value, _ = sess.run(
                        [encoder_loss, dis_loss, train_op], feed_dict=feed_dict)
                    if counter % args.save_pred_every == 0:
                        save(saver, sess, args.snapshot_dir, counter)
                    if counter % args.summary_pred_every == 0:
                        encoder_loss_sum_value, discriminator_sum_value = sess.run(
                            [encoder_loss_sum, discriminator_sum], feed_dict=feed_dict)
                        summary_writer.add_summary(encoder_loss_sum_value, counter)
                        summary_writer.add_summary(discriminator_sum_value, counter)
                    if counter % args.write_pred_every == 0:
                        fake_y_value = sess.run(fake_y, feed_dict=feed_dict)
                        write_image = get_write_picture(x_image_resize, y_image_resize, fake_y_value)
                        write_image_name = args.out_dir + "/out" + str(counter) + ".png"
                        cv2.imwrite(write_image_name, write_image)
                    if svalue_range is not None and counter % args.human_number == 0:
                        low, high = svalue_range
                        # Drawn ``args.duration`` times; only the last draw is kept.
                        for _ in range(args.duration):
                            args.svalue = random.uniform(low, high)
                            args.coefficient = 0.3
                        number += 1
                    print('epoch {:d} step {:d} \t encoder_loss = {:.3f}, dis_loss = {:.3f}'.format(epoch, step, encoder_loss_value, dis_loss_value))
    finally:
        # Flush pending summary events and release the event file.
        summary_writer.close()
示例#4
0
def main():
    """Build the nodule classifier graph and train it with cross-entropy.

    The classifier is optimised by a momentum optimiser (learning rate 0.01,
    momentum 0.99) with gradients clipped to a global norm of 1.  Each step
    writes the loss summary, plots the running accuracy of the current epoch
    and the central slice of the first input to ``vis``; a progress line is
    printed every 50 steps and a checkpoint is saved every second epoch.

    Reads the module-level names ``batch_size``, ``logs``, ``restore_from``,
    ``max_run`` and ``models_path`` and increments the global step counter
    ``iters``.
    """
    global iters
    vis.vis.texts = ''
    dice_loss_meter = tnt.meter.AverageValueMeter()
    image_batch = tf.placeholder(tf.float32, shape=[None, 48, 48, 48, 1])
    label_batch = tf.placeholder(tf.float32, shape=[None, 2])
    net = Classifier({'data': image_batch}, batch_size=batch_size)
    prob = net.layers['result']
    logits = net.layers['logits']
    dataset = LungDataset("/home/x/dcsb/data/TianChi", augument=True)

    all_trainable = tf.trainable_variables()
    # Snapshot taken before the optimiser is built, so variables created
    # after this point are not part of the saved/restored set.
    restore_var = tf.global_variables()

    cross_loss = tf.losses.softmax_cross_entropy(label_batch, logits)
    cross_loss_sum = tf.summary.scalar("crossloss", cross_loss)

    optimiser = tf.train.MomentumOptimizer(0.01, 0.99)
    gradients = tf.gradients(cross_loss, all_trainable)
    clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.)
    train_op = optimiser.apply_gradients(zip(clipped_gradients, all_trainable))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    all_sum = tf.summary.merge([cross_loss_sum])
    summary_writer = tf.summary.FileWriter(logs, graph=tf.get_default_graph())
    saver = tf.train.Saver(var_list=restore_var, max_to_keep=40)

    # Load variables if the checkpoint is provided.
    if restore_from is not None:
        loader = tf.train.Saver(var_list=restore_var)
        load(loader, sess, restore_from, "classifier_v2")

    for i in range(max_run):
        dice_loss_meter.reset()
        start_time = time.time()
        # Running accuracy over the samples seen so far in this epoch, kept
        # as counters so no placeholder sample skews the ratio and nothing
        # has to be re-concatenated on every step.
        correct = 0
        seen = 0
        accuracy = 0.0
        train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=1, pin_memory=True, drop_last=True)
        for batch_idx, (img_, label_, _) in enumerate(train_loader):
            iters += 1
            label = label_.numpy()
            # Reorder axes to (0, 2, 3, 4, 1) to match the placeholder.
            img = img_.numpy().transpose([0, 2, 3, 4, 1])
            feed_dict = {image_batch: img, label_batch: label}
            _, cross_loss_, probs, summary = sess.run(
                [train_op, cross_loss, prob, all_sum], feed_dict=feed_dict)
            summary_writer.add_summary(summary, iters)

            hits = np.argmax(label, axis=1) == np.argmax(probs, axis=1)
            correct += int(np.sum(hits))
            seen += hits.size
            accuracy = float(correct) / seen
            vis.plot('accuracy', accuracy)

            dice_loss_meter.add(cross_loss_)
            if batch_idx > 10:
                try:
                    vis.plot('cross_loss', dice_loss_meter.value()[0])
                except Exception:
                    # Plotting is best-effort and must not stop training;
                    # KeyboardInterrupt / SystemExit still propagate.
                    pass
            vis.img('input', img_[0, 0, 24, :, :].cpu().float())
            if iters % 50 == 0:
                # Evaluated in a separate run, i.e. after the weight update.
                cross = cross_loss.eval(feed_dict, session=sess)
                print("Epoch: [%2d]  [%4d] ,time: %4.4f,cross_loss:%.8f,accuracy:%.8f"% \
                      (i, batch_idx, time.time() - start_time, cross, accuracy))

        if i % 2 == 0:
            save(saver, sess, models_path, iters, "classifier_v2", train_tag="nodule_predict")
def main():

    train_vids, test_vids = data_util.load_dataset(args)
    iters = args.iters
    prefix = ("sto"
              + "_h=" + str(args.image_size_h)
              + "_w=" + str(args.image_size_w)
              + "_K=" + str(args.K)
              + "_T=" + str(args.T)
              + "_B=" + str(args.B)
              + "_batch_size=" + str(args.batch_size)
              + "_beta1=" + str(args.beta1)
              + "_alpha=" + str(args.alpha)
              + "_gamma=" + str(args.gamma)
              + "_lr=" + str(args.lr)
              + "_mode=" + str(args.mode)
              + "_space_aware=" + str(space_aware)
              + "_z_channel=" + str(args.z_channel)
              + "_p_loss=" + str(args.pixel_loss)
              + "_cell_type=" + str(args.cell_type)
              + "_norm=" + str(not args.no_normalized)
              + "_mask_w=" + str(args.mask_weight)
              + "_res_type=" + str(args.res_type)
              + "_neg_noise=" + str(not args.no_negative_noise)
              + "_res_ref=" + str(not args.no_res_ref)
              + "_pic_norm=" + str(not args.no_pic_norm)
              + "_start_perc=" + str(args.start_percentage)
    )

    print("\n" + prefix + "\n")
    checkpoint_dir = "../../models/stochastic/" + args.dataset + '/' + prefix + "/"
    samples_dir = "../../samples/stochastic/" + args.dataset + '/' + prefix + "/"
    summary_dir = "../../logs/stochastic/" + args.dataset + '/' + prefix + "/"

    if not exists(checkpoint_dir):
        makedirs(checkpoint_dir)
    #     save synthesized frame sample
    if not exists(samples_dir):
        makedirs(samples_dir)
    if not exists(summary_dir):
        makedirs(summary_dir)

    device_string = ""
    if args.cpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device_string = "/cpu:0"
    elif args.gpu:
        device_string = "/gpu:%d" % args.gpu[0]
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu[0])

    with tf.device(device_string):
        if args.mode == "bi_sto":
            model = stochastic_bi_net([args.image_size_h, args.image_size_w], batch_size = args.batch_size,
                       c_dim = args.color_channel_num, K=args.K, T=args.T, B=args.B, debug = False,
                       pixel_loss = args.pixel_loss, convlstm_kernel = [3, 3], mode = args.mode,
                       space_aware = space_aware, cell_type=args.cell_type, z_channel = args.z_channel,
                       normalize = not args.no_normalized, weight=args.mask_weight, res_type=args.res_type,
                       negative_noise = not args.no_negative_noise, res_ref = not args.no_res_ref,
                       pic_norm = not args.no_pic_norm)
        elif args.mode == "learned_prior":
            model = stochastic_learned_prior([args.image_size_h, args.image_size_w], batch_size = args.batch_size,
                       c_dim = args.color_channel_num, K=args.K, T=args.T, B=args.B, debug = False,
                       pixel_loss = args.pixel_loss, convlstm_kernel = [3, 3], mode = args.mode,
                       space_aware = space_aware, cell_type=args.cell_type, z_channel = args.z_channel,
                       normalize = not args.no_normalized, weight=args.mask_weight, res_type=args.res_type,
                       negative_noise = not args.no_negative_noise, res_ref = not args.no_res_ref,
                       pic_norm = not args.no_pic_norm)
        elif args.mode == "deter_flexible":
            model = deter_flexible([args.image_size_h, args.image_size_w], batch_size = args.batch_size,
                       c_dim = args.color_channel_num, K=args.K, T=args.T, B=args.B, debug = False,
                       pixel_loss = args.pixel_loss, convlstm_kernel = [3, 3], mode = args.mode,
                       space_aware = space_aware, cell_type=args.cell_type, z_channel = args.z_channel,
                       normalize = not args.no_normalized, weight=args.mask_weight, res_type=args.res_type,
                       negative_noise = not args.no_negative_noise, res_ref = not args.no_res_ref,
                       pic_norm = not args.no_pic_norm)
        global_step = tf.Variable(0, trainable=False)
        global_rate = tf.train.exponential_decay(args.lr, global_step,
                     args.decay_step, args.decay_rate, staircase=True)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            g_full = model.L_train_p + args.alpha * model.L_train_kl
            if args.gamma != 0:
                g_full += args.gamma * model.L_train_kl_exlusive
            g_optim = tf.train.AdamOptimizer(global_rate, beta1=args.beta1).minimize(
                g_full, var_list=model.trainable_variables, global_step=global_step
            )
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        tf.global_variables_initializer().run()
        if args.load_pretrain:
            if ops.load(model, sess, checkpoint_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")
        train_sum = tf.summary.merge([model.L_train_p_sum, model.L_train_kl_sum, model.L_train_p_l1_diff_sum,
                          model.L_trainTest_p_l1_diff_sum])
        test_sum = tf.summary.merge([model.L_test_p_sum, model.L_test_p_l1_diff_sum,
                                     model.L_testTrain_p_l1_diff_sum])

        writer = tf.summary.FileWriter(summary_dir, sess.graph)
        start_time = time.time()
        full_train = True
        # if (len(args.pretrain_model) > 0):
        #     # Create a saver. include gen_vars and encoder_vars
        #     model.saver.restore(sess, args.pretrain_model)
        blank = None
        p_loss_percentage = 1.0
        flipable = False
        if args.dataset=="kth":
            flipable = True
        blank1 = None
        while iters <= args.num_iter:
            mini_batches = get_minibatches_idx(len(train_vids), args.batch_size, shuffle=True)
            for _, batchidx in mini_batches:
                if args.start_percentage == 0.0:
                    p_loss_percentage = 0.5
                else:
                    if iters >= (args.num_iter * args.start_percentage):
                        if iters < args.num_iter * (1 - args.start_percentage):
                            p_loss_percentage = 1 - 0.6 * (
                                    (1.0 * iters / args.num_iter - args.start_percentage)
                                    / (1.0 - 2 * args.start_percentage))
                        else:
                            p_loss_percentage = 0.4
                if iters > args.num_iter: break
                if len(batchidx) == args.batch_size:
                    sess.run(tf.get_collection('update_dup'))
                    # batch, time, height, width, color
                    ref_batch, inf_batch = load_stochastic_data_from_list(train_vids, batchidx,
                                args.image_size_h, args.image_size_w, args.K, args.T,
                                  args.B, flipable=flipable, channel=args.color_channel_num)
                    if args.debug:
                        print ref_batch.shape, inf_batch.shape
                    _, summary_str , L_train_p, L_train_kl\
                        = sess.run([g_optim, train_sum, model.L_train_p, model.L_train_kl],
                                    feed_dict={model.ref_seq: ref_batch,
                                               model.inf_seq: inf_batch,
                                               model.is_train: True,
                                               model.p_loss_percentage: p_loss_percentage})
                    if not args.no_store: writer.add_summary(summary_str, iters)
                    print(
                        "Iters: [%2d] time: %4.4f, L_train_p: %.8f, L_train_kl: %.8f"
                            % (iters, time.time() - start_time, L_train_p, L_train_kl)
                    )

                    if np.mod(iters, 2500) == 0:
                        print("validation at iters:", iters)
                        ref_batch_train, inf_batch_train = load_stochastic_data_from_list(train_vids,
                                              range(3, 3 + args.batch_size/2)+range(60, 60 + args.batch_size/2),
                                              args.image_size_h, args.image_size_w,
                                              args.K, args.T, args.B, flipable=flipable, channel=args.color_channel_num)

                        ref_batch_test, inf_batch_test = load_stochastic_data_from_list(test_vids,
                                              range(3, 3 + args.batch_size/2)+range(60, 60 + args.batch_size/2),
                                              args.image_size_h,
                                              args.image_size_w,
                                              args.K, args.T, args.B, flipable=flipable, channel=args.color_channel_num)
                        if blank1 is None:
                            blank1 = np.zeros_like(ref_batch_train[0, :args.B // 2 + 1, ...])
                            blank2 = np.zeros_like(ref_batch_train[0, args.B//2+1: , ...])
                        summary_test, L_test_p, L_test_kl, \
                        G_test, G_test_post, test_mask_binary, last_frame_test = sess.run(
                            [test_sum, model.L_train_p, model.L_train_kl, model.G_real,
                             model.G_post_real, model.mask_binary, model.last_frame],
                            feed_dict={model.ref_seq: ref_batch_test,
                                       model.inf_seq: inf_batch_test,
                                       model.is_train: False,
                                       model.p_loss_percentage: p_loss_percentage})

                        _, _, _, _, _, mean_batch_psnr_test_post, mean_batch_ssim_test_post\
                            = metrics.cal_seq(inf_batch_test[:, 1:-1, ...], G_test_post)
                        _, _, _, _, _, mean_batch_psnr_test, mean_batch_ssim_test \
                            = metrics.cal_seq(inf_batch_test[:, 1:-1, ...], G_test)

                        writer.add_summary(summary_test, iters)
                        print(
                            "Iters: [%2d] time: %4.4f, L_test_p: %.8f, L_test_kl: %.8f"
                                % (iters, time.time() - start_time, L_test_p, L_test_kl)
                        )
                        print("ref_batch_test.min, ref_batch_test.max", np.min(ref_batch_test), np.max(ref_batch_test))
                        print("mean_batch_psnr_test_post, mean_batch_ssim_test_post",
                              mean_batch_psnr_test_post, mean_batch_ssim_test_post)
                        print("mean_batch_psnr_test, mean_batch_ssim_test",
                              mean_batch_psnr_test, mean_batch_ssim_test)
                        print "test G_test.shape", G_test.shape
                        summary_train, L_train_p, L_train_kl, G_train, \
                        G_train_post, train_mask_binary, last_frame_train = sess.run(
                            [train_sum, model.L_train_p, model.L_train_kl, model.G_real, model.G_post_real,
                             model.mask_binary, model.last_frame],
                            feed_dict={model.ref_seq: ref_batch_train,
                                       model.inf_seq: inf_batch_train,
                                       model.is_train: True,
                                       model.p_loss_percentage: p_loss_percentage})

                        _, _, _, _, _, mean_batch_psnr_train_post, mean_batch_ssim_train_post \
                            = metrics.cal_seq(inf_batch_train[:, 1:-1, ...], G_train_post)
                        _, _, _, _, _, mean_batch_psnr_train, mean_batch_ssim_train \
                            = metrics.cal_seq(inf_batch_train[:, 1:-1, ...], G_train)
                        print("mean_batch_psnr_train_post, mean_batch_ssim_train_post",
                              mean_batch_psnr_train_post, mean_batch_ssim_train_post)
                        print("mean_batch_psnr_train, mean_batch_ssim_train",
                              mean_batch_psnr_train, mean_batch_ssim_train)
                        for i in [1, args.batch_size/2 ,args.batch_size - 1]:
                            sample_train = depth_to_width(np.concatenate(
                                (ref_batch_train[i,:args.B//2,...],
                                 inf_batch_train[i,...], ref_batch_train[i,args.B//2+2:,...]), axis=0))
                            gen_train_mask = depth_to_width(np.concatenate(
                                (blank1, train_mask_binary[i, ...], blank2),axis=0))
                            gen_train_post = depth_to_width(np.concatenate(
                                (blank1, G_train_post[i, ...], blank2), axis=0))
                            gen_train = depth_to_width(np.concatenate(
                                (blank1, G_train[i, ...], blank2),axis=0))
                            sample_test = depth_to_width(np.concatenate(
                                (ref_batch_test[i,:args.B//2,...],
                                 inf_batch_test[i,...], ref_batch_test[i,args.B//2+2:,...]),axis=0))
                            gen_test_mask = depth_to_width(np.concatenate(
                                (blank1, test_mask_binary[i, ...], blank2), axis=0))
                            gen_test_post = depth_to_width(np.concatenate(
                                (blank1, G_test_post[i, ...], blank2), axis=0))
                            gen_test = depth_to_width(np.concatenate(
                                (blank1, G_test[i, ...], blank2),axis=0))
                            if i == 1:
                                print sample_train.shape, gen_train.shape, sample_train.shape
                                sample_train_cat = np.concatenate((sample_train, gen_train_mask, gen_train_post, gen_train), axis=0)
                                sample_test_cat = np.concatenate((sample_test, gen_test_mask, gen_test_post, gen_test), axis=0)
                            else:
                                sample_train_cat = np.concatenate(
                                    (sample_train_cat, sample_train, gen_train_mask, gen_train_post, gen_train), axis=0)
                                sample_test_cat = np.concatenate(
                                    (sample_test_cat, sample_test, gen_test_mask, gen_test_post, gen_test), axis=0)
                        print("Saving sample at iter"), iters
                        img_summary = sess.run(model.summary_merge_seq_img, feed_dict={
                            model.train_seq_img: np.expand_dims(image_clipping(sample_train_cat), axis=0),
                            model.test_seq_img: np.expand_dims(image_clipping(sample_test_cat), axis=0)
                        })
                        metrics_summary = sess.run(
                            model.summary_merge_metrics, feed_dict={
                                model.mean_batch_psnr_test_post: mean_batch_psnr_test_post,
                                model.mean_batch_psnr_test: mean_batch_psnr_test,
                                model.mean_batch_psnr_train_post: mean_batch_psnr_train_post,
                                model.mean_batch_psnr_train: mean_batch_psnr_train,
                                model.mean_batch_ssim_test_post: mean_batch_ssim_test_post,
                                model.mean_batch_ssim_test: mean_batch_ssim_test,
                                model.mean_batch_ssim_train_post: mean_batch_ssim_train_post,
                                model.mean_batch_ssim_train: mean_batch_ssim_train
                            }
                        )
                        if not args.no_store:
                            writer.add_summary(img_summary, iters)
                            writer.add_summary(metrics_summary, iters)
                    if np.mod(iters, 10000) == 0 and iters != 0 and not args.no_store:
                        ops.save(model, sess, checkpoint_dir, iters)
                iters += 1
        print "finish Training"