Beispiel #1
0
class GANTrainer(object):
    def __init__(self, output_dir):
        """Prepare output locations, summary writers and device settings.

        When ``cfg.TRAIN.FLAG`` is set, the ``Model``, ``Image`` and ``Log``
        sub-directories of ``output_dir`` are created, and one summary writer
        is opened on the local log directory and a second one on the fixed
        Drive log directory. Epoch count, snapshot interval, batch size and
        the GPU id list are always read from ``cfg``.

        Args:
            output_dir: Root directory for checkpoints, sample images and
                logs.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir, self.image_dir, self.log_dir = (
                os.path.join(output_dir, sub_dir)
                for sub_dir in ('Model', 'Image', 'Log'))
            self.drive_log_dir = '/content/drive/My Drive/coco2Log'
            for directory in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(directory)
            self.summary_writer = FileWriter(self.log_dir)
            self.drive_summary_writer = FileWriter(self.drive_log_dir)

        train_cfg = cfg.TRAIN
        self.max_epoch = train_cfg.MAX_EPOCH
        self.snapshot_interval = train_cfg.SNAPSHOT_INTERVAL

        # cfg.GPU_ID is a comma-separated string of device ids.
        self.gpus = list(map(int, cfg.GPU_ID.split(',')))
        self.num_gpus = len(self.gpus)
        self.batch_size = train_cfg.BATCH_SIZE
        # The first listed id becomes the current CUDA device.
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build the stage-I generator and discriminator.

        Both networks are initialised with ``weights_init`` and printed.
        When ``cfg.NET_G`` / ``cfg.NET_D`` are non-empty, the ``"netG"`` /
        ``"netD"`` entries of those checkpoints are loaded. With ``cfg.CUDA``
        set, both networks are moved to the GPU.

        Returns:
            tuple: ``(netG, netD)``.
        """
        from model import STAGE1_G, STAGE1_D

        def keep_storage(storage, loc):
            # map_location hook: hand back the storage unchanged.
            return storage

        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)

        print('Your now at this Locatin : ')

        restore_plan = (
            (netG, cfg.NET_G, "netG"),
            (netD, cfg.NET_D, "netD"),
        )
        for network, checkpoint_path, entry in restore_plan:
            if checkpoint_path == '':
                continue
            checkpoint = torch.load(checkpoint_path,
                                    map_location=keep_storage)
            network.load_state_dict(checkpoint[entry])
            print('Load from: ', checkpoint_path)

        if cfg.CUDA:
            for network in (netG, netD):
                network.cuda()
        return netG, netD

    # ############# For training stageII GAN  #############
    def load_network_stageII(self):
        """Build the stage-II generator (wrapping a stage-I one) and discriminator.

        The generator is restored from ``cfg.NET_G`` when that is set;
        otherwise only its embedded stage-I generator is restored from
        ``cfg.STAGE1_G``. The discriminator is restored from ``cfg.NET_D``
        when set. With ``cfg.CUDA`` set, both networks are moved to the GPU.

        Returns:
            tuple: ``(netG, netD)``.

        Raises:
            ValueError: If neither ``cfg.NET_G`` nor ``cfg.STAGE1_G`` is set.
                Earlier this case printed a message and returned ``None``,
                which made the callers in this class fail later with an
                unrelated unpacking ``TypeError``.
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        Stage1_G = STAGE1_G()
        netG = STAGE2_G(Stage1_G)
        netG.apply(weights_init)
        print(netG)
        print(
            "---------------------------------------------------------------------------------"
        )
        print("Current Working Directory : ", os.getcwd())
        print("cfg.NET_G : ", cfg.NET_G)
        print(
            "---------------------------------------------------------------------------------"
        )
        if cfg.NET_G != '':
            print("THIS THE CURRENT DIRECTORY : ", os.getcwd())
            state_dict = torch.load(cfg.NET_G,
                                    map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict["netG"])
            print('Load from: ', cfg.NET_G)
        elif cfg.STAGE1_G != '':
            state_dict = torch.load(cfg.STAGE1_G,
                                    map_location=lambda storage, loc: storage)
            # Only the embedded stage-I generator is restored here.
            netG.STAGE1_G.load_state_dict(state_dict["netG"])
            print('Load from: ', cfg.STAGE1_G)
        else:
            # Fail at the source of the problem instead of returning None.
            raise ValueError("Please give the Stage1_G path")

        netD = STAGE2_D()
        netD.apply(weights_init)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict["netD"])
            print('Load from: ', cfg.NET_D)
        print(netD)

        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader, stage=1, max_objects=3):
        """Run the adversarial training loop.

        Each batch is turned into transformation matrices and one-hot labels,
        fake images are generated, then the discriminator and the generator
        are updated in turn. Losses are written to the Drive summary writer
        every 100 batches and additionally printed, written to the local
        summary writer and accompanied by sample images every 500 batches.
        A checkpoint is saved every ``snapshot_interval`` epochs and once more
        after the last epoch.

        Args:
            data_loader: Iterable yielding
                ``(real_img_cpu, bbox, label, txt_embedding)`` batches. When
                ``cfg.STAGE == 2``, ``bbox`` must hold two entries
                (``bbox[0]`` and ``bbox[1]``).
            stage: Chooses which networks are loaded (1 -> stage I, anything
                else -> stage II). Per-batch processing branches on
                ``cfg.STAGE``, not on this argument.
            max_objects: Object slots per image, used when reshaping the
                transformation matrices and sizing the one-hot labels.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # This draw is not read anywhere below; it is kept so the CPU random
        # stream is consumed exactly as before.
        torch.FloatTensor(batch_size, nz).normal_(0, 1)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise = noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        # Only parameters that require gradients are handed to the optimizer.
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerD = optim.Adam(netD.parameters(),
                                lr=discriminator_lr,
                                betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para,
                                lr=generator_lr,
                                betas=(0.5, 0.999))

        # Resume optimizer state and epoch counter from the generator
        # checkpoint when one is configured.
        # NOTE(review): generator_lr / discriminator_lr restart from the cfg
        # values on resume, so the next decay step overrides whatever rate the
        # restored optimizers carried - confirm this is intended.
        startpoint = -1
        if cfg.NET_G != '':
            state_dict = torch.load(cfg.NET_G,
                                    map_location=lambda storage, loc: storage)
            optimizerD.load_state_dict(state_dict["optimD"])
            optimizerG.load_state_dict(state_dict["optimG"])
            startpoint = state_dict["epoch"]
            print(startpoint)
            print('Load Optim and optimizers as : ', cfg.NET_G)

        count = 0
        drive_count = 0
        last_epoch = None
        for epoch in range(startpoint + 1, self.max_epoch):
            print('epoch : ', epoch, ' drive_count : ', drive_count)
            epoch_start_time = time.time()
            print(epoch)
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                # Halve both learning rates every lr_decay_step epochs.
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            time_to_i = time.time()
            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, bbox, label, txt_embedding = data

                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    if cfg.STAGE == 1:
                        bbox = bbox.cuda()
                    elif cfg.STAGE == 2:
                        bbox = [bbox[0].cuda(), bbox[1].cuda()]
                    label = label.cuda()
                    txt_embedding = txt_embedding.cuda()

                n_imgs = real_imgs.shape[0]
                matrix_shape = (n_imgs, max_objects, 2, 3)
                if cfg.STAGE == 1:
                    bbox = bbox.view(-1, 4)
                    transf_matrices_inv = \
                        compute_transformation_matrix_inverse(bbox).view(
                            *matrix_shape)
                    transf_matrices = compute_transformation_matrix(
                        bbox).view(*matrix_shape)
                    # Layout tensors fed to the generator / to the losses.
                    layout_inputs = (transf_matrices_inv, )
                    loss_transf = transf_matrices
                    loss_transf_inv = transf_matrices_inv
                elif cfg.STAGE == 2:
                    _bbox = bbox[0].view(-1, 4)
                    transf_matrices_inv = \
                        compute_transformation_matrix_inverse(_bbox).view(
                            *matrix_shape)

                    _bbox = bbox[1].view(-1, 4)
                    transf_matrices_inv_s2 = \
                        compute_transformation_matrix_inverse(_bbox).view(
                            *matrix_shape)
                    transf_matrices_s2 = compute_transformation_matrix(
                        _bbox).view(*matrix_shape)
                    layout_inputs = (transf_matrices_inv, transf_matrices_s2,
                                     transf_matrices_inv_s2)
                    loss_transf = transf_matrices_s2
                    loss_transf_inv = transf_matrices_inv_s2

                # produce one-hot encodings of the labels
                _labels = label.long()
                # remove -1 to enable one-hot converting
                _labels[_labels < 0] = 80
                # NOTE(review): sized by the configured batch size rather than
                # by the batch actually received - presumably relies on the
                # loader dropping a short final batch; confirm.
                if cfg.CUDA:
                    label_one_hot = torch.cuda.FloatTensor(
                        noise.shape[0], max_objects, 81).fill_(0)
                else:
                    label_one_hot = torch.FloatTensor(noise.shape[0],
                                                      max_objects, 81).fill_(0)
                label_one_hot = label_one_hot.scatter_(2, _labels, 1).float()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = ((txt_embedding, noise) + layout_inputs +
                          (label_one_hot, ))
                _, fake_imgs, mu, logvar, _ = self._run_generator(
                    netG, inputs)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot, loss_transf,
                                               loss_transf_inv, mu, self.gpus)
                # retain_graph: presumably needed because the generator loss
                # below backpropagates through the same forward pass.
                errD.backward(retain_graph=True)
                optimizerD.step()

                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                              label_one_hot, loss_transf,
                                              loss_transf_inv, mu, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                ############################
                # (5) Logging
                ###########################
                if i % 100 == 0:
                    # Built fresh for every log point so the Drive log never
                    # repeats values from an earlier batch.
                    loss_summaries = [
                        summary.scalar('D_loss', errD.item()),
                        summary.scalar('D_loss_real', errD_real),
                        summary.scalar('D_loss_wrong', errD_wrong),
                        summary.scalar('D_loss_fake', errD_fake),
                        summary.scalar('G_loss', errG.item()),
                        summary.scalar('KL_loss', kl_loss.item()),
                    ]

                    if i % 500 == 0:
                        count += 1
                        print('epoch     :  ', epoch)
                        print('count     :  ', count)
                        print('  i       :  ', i)
                        print('Time to i : ', time.time() - time_to_i)
                        time_to_i = time.time()
                        print('D_loss : ', errD.item())
                        print('D_loss_real : ', errD_real)
                        print('D_loss_wrong : ', errD_wrong)
                        print('D_loss_fake : ', errD_fake)
                        print('G_loss : ', errG.item())
                        print('KL_loss : ', kl_loss.item())
                        print('generator_lr : ', generator_lr)
                        print('discriminator_lr : ', discriminator_lr)
                        print('lr_decay_step : ', lr_decay_step)

                        self._write_summaries(self.summary_writer,
                                              loss_summaries, count)

                        # Save sample images generated from this batch.
                        with torch.no_grad():
                            lr_fake, fake, _, _, _ = self._run_generator(
                                netG, inputs)
                            save_img_results(real_img_cpu, fake, epoch,
                                             self.image_dir)
                            if lr_fake is not None:
                                save_img_results(None, lr_fake, epoch,
                                                 self.image_dir)

                    drive_count += 1
                    self._write_summaries(self.drive_summary_writer,
                                          loss_summaries, drive_count)

            # End of epoch: sample images from the last batch's inputs.
            with torch.no_grad():
                lr_fake, fake, _, _, _ = self._run_generator(netG, inputs)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch,
                           self.model_dir)

            print("keyTime |||||||||||||||||||||||||||||||")
            print("epoch_time : ", time.time() - epoch_start_time)
            print("KeyTime |||||||||||||||||||||||||||||||")
            last_epoch = epoch

        # Final checkpoint. Skipped when no epoch ran (e.g. resuming from a
        # checkpoint already at max_epoch), where `epoch` would be undefined.
        if last_epoch is not None:
            save_model(netG, netD, optimizerG, optimizerD, last_epoch,
                       self.model_dir)
        self.summary_writer.close()
        self.drive_summary_writer.close()

    def _run_generator(self, netG, inputs):
        """Run ``netG`` on ``inputs``: data-parallel on GPU, a direct call on CPU."""
        if cfg.CUDA:
            return nn.parallel.data_parallel(netG, inputs, self.gpus)
        return netG(*inputs)

    @staticmethod
    def _write_summaries(writer, summaries, step):
        """Add every summary in ``summaries`` to ``writer`` at ``step``."""
        for item in summaries:
            writer.add_summary(item, step)

    def sample(self,
               datapath,
               num_samples=25,
               stage=1,
               draw_bbox=True,
               max_objects=3):
        """Generate image grids for randomly chosen validation captions.

        For each of ``num_samples`` randomly drawn validation entries, the
        real image plus nine generated images are written as one PNG row into
        ``<cfg.NET_G without '.pth'>_visualize_bbox``; the file is named after
        the caption. Optionally the entry's bounding boxes are drawn onto all
        ten images.

        Args:
            datapath: Directory prefix holding ``val_captions.t7`` and
                ``filenames.pickle`` (also passed to
                ``load_validation_data``).
            num_samples: Number of validation entries to draw.
            stage: Chooses which generator is loaded and the image size
                (64 for stage 1, 256 otherwise). Tensor preparation branches
                on ``cfg.STAGE``, not on this argument.
            draw_bbox: Whether to draw box outlines onto the saved grid.
            max_objects: Object slots per image.
        """
        from PIL import Image
        try:
            import cPickle as pickle  # Python 2
        except ImportError:
            import pickle  # Python 3
        import torchvision
        import torchvision.utils as vutils
        img_dir = cfg.IMG_DIR
        if stage == 1:
            netG, _ = self.load_network_stageI()
        else:
            netG, _ = self.load_network_stageII()
        netG.eval()

        # Load text embeddings generated from the encoder
        t_file = torchfile.load(datapath + "val_captions.t7")
        captions_list = t_file.raw_txt
        embeddings = np.concatenate(t_file.fea_txt, axis=0)
        num_embeddings = len(captions_list)
        label, bbox = load_validation_data(datapath)

        filepath = os.path.join(datapath, 'filenames.pickle')
        # NOTE: unpickling can execute arbitrary code - only point this at
        # trusted dataset files.
        with open(filepath, 'rb') as f:
            filenames = pickle.load(f)
        print('Successfully load sentences from: ', datapath)
        print('Total number of sentences:', num_embeddings)
        # path to save generated samples; strip '.pth' only when it is
        # actually present (find() == -1 used to chop the last character)
        ext_at = cfg.NET_G.find('.pth')
        checkpoint_stem = cfg.NET_G[:ext_at] if ext_at != -1 else cfg.NET_G
        save_dir = checkpoint_stem + "_visualize_bbox"
        print("saving to:", save_dir)
        mkdir_p(save_dir)

        if cfg.CUDA:
            bbox = bbox.cuda()
            label = label.cuda()
        if cfg.STAGE == 2:
            # Two independent copies of the same boxes; built on CPU as well
            # so that bbox[0] / bbox[1] below always index this list.
            bbox = [bbox.clone(), bbox]

        #######################################
        # Untouched copy of the boxes, used for drawing further down.
        if cfg.STAGE == 1:
            bbox_ = bbox.clone()
        elif cfg.STAGE == 2:
            bbox_ = bbox[0].clone()

        matrix_shape = (num_embeddings, max_objects, 2, 3)
        if cfg.STAGE == 1:
            bbox = bbox.view(-1, 4)
            transf_matrices_inv = compute_transformation_matrix_inverse(
                bbox).view(*matrix_shape)
        elif cfg.STAGE == 2:
            _bbox = bbox[0].view(-1, 4)
            transf_matrices_inv = compute_transformation_matrix_inverse(
                _bbox).view(*matrix_shape)

            _bbox = bbox[1].view(-1, 4)
            transf_matrices_inv_s2 = compute_transformation_matrix_inverse(
                _bbox).view(*matrix_shape)
            transf_matrices_s2 = compute_transformation_matrix(_bbox).view(
                *matrix_shape)

        # produce one-hot encodings of the labels
        _labels = label.long()
        # remove -1 to enable one-hot converting
        _labels[_labels < 0] = 80
        tensor_type = torch.cuda.FloatTensor if cfg.CUDA else torch.FloatTensor
        label_one_hot = tensor_type(num_embeddings, max_objects, 81).fill_(0)
        label_one_hot = label_one_hot.scatter_(2, _labels, 1).float()
        #######################################

        n_variants = 9  # generated images per caption
        nz = cfg.Z_DIM
        noise = Variable(torch.FloatTensor(n_variants, nz))
        if cfg.CUDA:
            noise = noise.cuda()

        imsize = 64 if stage == 1 else 256
        last_px = imsize - 1

        for _ in range(num_samples):
            index = int(np.random.randint(0, num_embeddings))
            key = filenames[index]
            img_name = img_dir + "/" + key + ".jpg"
            # Image.LANCZOS is the same filter as the removed ANTIALIAS alias.
            img = Image.open(img_name).convert('RGB').resize((imsize, imsize),
                                                             Image.LANCZOS)
            val_image = torchvision.transforms.functional.to_tensor(img)
            val_image = val_image.view(1, 3, imsize, imsize)
            # Rescale from [0, 1] to [-1, 1].
            val_image = (val_image - 0.5) * 2

            # Replicate this entry's conditioning n_variants times so a single
            # forward pass yields n_variants different samples.
            embeddings_batch = np.reshape(embeddings[index],
                                          (1, 1024)).repeat(n_variants, 0)
            transf_matrices_inv_batch = transf_matrices_inv[index].view(
                1, max_objects, 2, 3).repeat(n_variants, 1, 1, 1)
            label_one_hot_batch = label_one_hot[index].view(
                1, max_objects, 81).repeat(n_variants, 1, 1)

            if cfg.STAGE == 2:
                transf_matrices_s2_batch = transf_matrices_s2[index].view(
                    1, max_objects, 2, 3).repeat(n_variants, 1, 1, 1)
                transf_matrices_inv_s2_batch = \
                    transf_matrices_inv_s2[index].view(
                        1, max_objects, 2, 3).repeat(n_variants, 1, 1, 1)

            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if cfg.CUDA:
                label_one_hot_batch = label_one_hot_batch.cuda()
                txt_embedding = txt_embedding.cuda()

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            if cfg.STAGE == 1:
                inputs = (txt_embedding, noise, transf_matrices_inv_batch,
                          label_one_hot_batch)
            elif cfg.STAGE == 2:
                inputs = (txt_embedding, noise, transf_matrices_inv_batch,
                          transf_matrices_s2_batch,
                          transf_matrices_inv_s2_batch, label_one_hot_batch)
            with torch.no_grad():
                if cfg.CUDA:
                    _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(
                        netG, inputs, self.gpus)
                else:
                    _, fake_imgs, mu, logvar, _ = netG(*inputs)

            # Row layout: the real image first, then the generated ones.
            data_img = torch.FloatTensor(n_variants + 1, 3, imsize,
                                         imsize).fill_(0)
            data_img[0] = val_image
            data_img[1:] = fake_imgs

            if draw_bbox:
                for idx in range(max_objects):
                    x, y, w, h = [
                        int(imsize * coord) for coord in bbox_[index, idx]
                    ]
                    w = min(w, last_px)
                    h = min(h, last_px)
                    # A negative x ends the box list for this entry.
                    if x <= -1:
                        break
                    # Clamp every integer index to the image so a box that
                    # touches the border cannot raise IndexError.
                    x = min(x, last_px)
                    y = min(y, last_px)
                    right = min(x + w, last_px)
                    bottom = min(y + h, last_px)
                    data_img[:, :, y, x:x + w] = 1
                    data_img[:, :, y:y + h, x] = 1
                    data_img[:, :, bottom, x:x + w] = 1
                    data_img[:, :, y:y + h, right] = 1

            vutils.save_image(data_img,
                              '{}/{}.png'.format(save_dir,
                                                 captions_list[index]),
                              normalize=True,
                              nrow=n_variants + 1)

        print("Saved {} files to {}".format(num_samples, save_dir))
Beispiel #2
0
class Experiment(SummaryWriter):
    def __init__(self,
                 save_dir=None,
                 name='default',
                 debug=False,
                 version=None,
                 autosave=False,
                 description=None,
                 create_git_tag=False,
                 rank=0,
                 *args,
                 **kwargs):
        """
        A new Experiment object defaults to 'default' unless a specific name is provided
        If a known name is already provided, then the file version is changed
        :param name:
        :param debug:
        """

        # change where the save dir is if requested

        if save_dir is not None:
            global _ROOT
            _ROOT = save_dir

        self.save_dir = save_dir
        self.tag_markdown_saved = False
        self.no_save_dir = save_dir is None
        self.metrics = []
        self.tags = {}
        self.name = name
        self.debug = debug
        self.version = version
        self.autosave = autosave
        self.description = description
        self.create_git_tag = create_git_tag
        self.exp_hash = '{}_v{}'.format(self.name, version)
        self.created_at = str(datetime.utcnow())
        self.rank = rank
        self.process = os.getpid()

        # when debugging don't do anything else
        if debug:
            return

        # update version hash if we need to increase version on our own
        # we will increase the previous version, so do it now so the hash
        # is accurate
        if version is None:
            old_version = self.__get_last_experiment_version()
            self.exp_hash = '{}_v{}'.format(self.name, old_version + 1)
            self.version = old_version + 1

        # create a new log file
        self.__init_cache_file_if_needed()

        # when we have a version, load it
        if self.version is not None:

            # when no version and no file, create it
            if not os.path.exists(self.__get_log_name()):
                self.__create_exp_file(self.version)
            else:
                # otherwise load it
                try:
                    self.__load()
                except Exception as e:
                    self.debug = True
        else:
            # if no version given, increase the version to a new exp
            # create the file if not exists
            old_version = self.__get_last_experiment_version()
            self.version = old_version
            self.__create_exp_file(self.version + 1)

        # create a git tag if requested
        if self.create_git_tag:
            desc = description if description is not None else 'no description'
            tag_msg = 'Test tube exp: {} - {}'.format(self.name, desc)
            cmd = 'git tag -a tt_{} -m "{}"'.format(self.exp_hash, tag_msg)
            os.system(cmd)
            print('Test tube created git tag:', 'tt_{}'.format(self.exp_hash))

        # set the tensorboardx log path to the /tf folder in the exp folder
        log_dir = self.get_tensorboardx_path(self.name, self.version)
        # this is a fix for pytorch 1.1 since it does not have this attribute
        for attr, val in [('purge_step', None), ('max_queue', 10),
                          ('flush_secs', 120), ('filename_suffix', '')]:
            if not hasattr(self, attr):
                setattr(self, attr, val)
        super().__init__(log_dir=log_dir, *args, **kwargs)

        # register on exit fx so we always close the writer
        # atexit.register(self.on_exit)

    def get_meta_copy(self):
        """
        Build a meta-version only copy of this experiment.

        :return: a ``DDPExperiment`` constructed from this instance
        """
        meta_copy = DDPExperiment(self)
        return meta_copy

    def on_exit(self):
        pass

    def __clean_dir(self):
        files = os.listdir(self.save_dir)

        if self.rank == 0:
            return

        for f in files:
            if str(self.process) in f:
                os.remove(os.path.join(self.save_dir, f))

    def argparse(self, argparser):
        parsed = vars(argparser)
        to_add = {}

        # don't store methods
        for k, v in parsed.items():
            if not callable(v):
                to_add[k] = v

        self.tag(to_add)

    def add_meta_from_hyperopt(self, hypo):
        """
        Transfers meta data about all the params from the
        hyperoptimizer to the log
        :param hypo:
        :return:
        """
        meta = hypo.get_current_trial_meta()
        for tag in meta:
            self.tag(tag)

    # --------------------------------
    # FILE IO UTILS
    # --------------------------------
    def __init_cache_file_if_needed(self):
        """
        Inits a file that we log historical experiments
        :return:
        """
        try:
            exp_cache_file = self.get_data_path(self.name, self.version)
            if not os.path.isdir(exp_cache_file):
                os.makedirs(exp_cache_file, exist_ok=True)
        except Exception as e:
            # file already exists (likely written by another exp. In this case disable the experiment
            self.debug = True

    def __create_exp_file(self, version):
        """
        Recreates the old file with this exp and version
        :param version:
        :return:
        """

        try:
            exp_cache_file = self.get_data_path(self.name, self.version)
            # if no exp, then make it
            path = '{}/meta.experiment'.format(exp_cache_file)
            open(path, 'w').close()
            self.version = version

            # make the directory for the experiment media assets name
            os.makedirs(self.get_media_path(self.name, self.version),
                        exist_ok=True)

            # make the directory for tensorboardx stuff
            os.makedirs(self.get_tensorboardx_path(self.name, self.version),
                        exist_ok=True)
        except Exception as e:
            # file already exists (likely written by another exp. In this case disable the experiment
            self.debug = True

    def __get_last_experiment_version(self):
        try:
            exp_cache_file = os.sep.join(
                self.get_data_path(self.name, self.version).split(os.sep)[:-1])
            return find_last_experiment_version(exp_cache_file)
        except Exception as e:
            return -1

    def __get_log_name(self):
        exp_cache_file = self.get_data_path(self.name, self.version)
        return '{}/meta.experiment'.format(exp_cache_file)

    def tag(self, tag_dict):
        """
        Adds a tag to the experiment.
        Tags are metadata for the exp.

        >> e.tag({"model": "Convnet A"})

        :param key:
        :param val:
        :return:
        """
        if self.debug or self.rank > 0: return

        # parse tags
        for k, v in tag_dict.items():
            self.tags[k] = v

        # save if needed
        if self.autosave == True:
            self.save()

    def log(self, metrics_dict, global_step=None, walltime=None):
        """
        Adds a json dict of metrics.

        >> e.log({"loss": 23, "coeff_a": 0.2})

        :param metrics_dict:
        :tag optional tfx tag
        :return:
        """
        if self.debug or self.rank > 0: return

        # handle tfx metrics
        if global_step is None:
            global_step = len(self.metrics)

        new_metrics_dict = metrics_dict.copy()
        for k, v in metrics_dict.items():
            if isinstance(v, dict):
                self.add_scalars(main_tag=k,
                                 tag_scalar_dict=v,
                                 global_step=global_step,
                                 walltime=walltime)
                tmp_metrics_dict = new_metrics_dict.pop(k)
                new_metrics_dict.update(tmp_metrics_dict)
            else:
                self.add_scalar(tag=k,
                                scalar_value=v,
                                global_step=global_step,
                                walltime=walltime)

        metrics_dict = new_metrics_dict

        # timestamp
        if 'created_at' not in metrics_dict:
            metrics_dict['created_at'] = str(datetime.utcnow())

        self.__convert_numpy_types(metrics_dict)

        self.metrics.append(metrics_dict)

        if self.autosave:
            self.save()

    def __convert_numpy_types(self, metrics_dict):
        for k, v in metrics_dict.items():
            if v.__class__.__name__ == 'float32':
                metrics_dict[k] = float(v)

            if v.__class__.__name__ == 'float64':
                metrics_dict[k] = float(v)

    def save(self):
        """
        Persist the current state of the experiment.

        Writes the meta file, the tags CSV and the metrics CSV (each through
        ``atomic_write``), flushes the writers and, the first time tags are
        present, logs them as a markdown table under 'hparams'.
        Does nothing in debug mode or for ranks above 0.

        :return: None
        """
        if self.debug or self.rank > 0:
            return

        # write image metrics to disk first so the rows only hold file names
        self.__save_images(self.metrics)

        metrics_file_path = (
            self.get_data_path(self.name, self.version) + '/metrics.csv')
        meta_tags_path = (
            self.get_data_path(self.name, self.version) + '/meta_tags.csv')

        experiment_meta = dict(
            name=self.name,
            version=self.version,
            tags_path=meta_tags_path,
            metrics_path=metrics_file_path,
            autosave=self.autosave,
            description=self.description,
            created_at=self.created_at,
            exp_hash=self.exp_hash,
        )

        # experiment meta file
        with atomic_write(self.__get_log_name()) as tmp_path, \
                open(tmp_path, 'w') as file:
            json.dump(experiment_meta, file, ensure_ascii=False)

        # meta tags file
        tags_frame = pd.DataFrame({
            'key': list(self.tags.keys()),
            'value': list(self.tags.values()),
        })
        with atomic_write(meta_tags_path) as tmp_path:
            tags_frame.to_csv(tmp_path, index=False)

        # metrics data
        metrics_frame = pd.DataFrame(self.metrics)
        with atomic_write(metrics_file_path) as tmp_path:
            metrics_frame.to_csv(tmp_path, index=False)

        # write new vals to disk
        self.flush()

        # until hparam plugin is fixed, generate hparams as text (only once)
        if self.tag_markdown_saved or len(self.tags) == 0:
            return
        self.tag_markdown_saved = True
        self.add_text('hparams', self.__generate_tfx_meta_log())

    def __generate_tfx_meta_log(self):
        header = f'''###### {self.name}, version {self.version}\n---\n'''
        desc = ''
        if self.description is not None:
            desc = f'''#####*{self.description}*\n'''
        params = f'''##### Hyperparameters\n'''

        row_header = '''parameter|value\n-|-\n'''
        rows = [row_header]
        for k, v in self.tags.items():
            row = f'''{k}|{v}\n'''
            rows.append(row)

        all_rows = [header, desc, params]
        all_rows.extend(rows)
        mkdown_log = ''.join(all_rows)
        return mkdown_log

    def __save_images(self, metrics):
        """
        Save tags that have a png_ prefix (as images)
        and replace the meta tag with the file name
        :param metrics:
        :return:
        """
        # iterate all metrics and find keys with a specific prefix
        for i, metric in enumerate(metrics):
            for k, v in metric.items():
                # if the prefix is a png, save the image and replace the value with the path
                img_extension = None
                img_extension = 'png' if 'png_' in k else img_extension
                img_extension = 'jpg' if 'jpg' in k else img_extension
                img_extension = 'jpeg' if 'jpeg' in k else img_extension

                if img_extension is not None:
                    # determine the file name
                    img_name = '_'.join(k.split('_')[1:])
                    save_path = self.get_media_path(self.name, self.version)
                    save_path = '{}/{}_{}.{}'.format(save_path, img_name, i,
                                                     img_extension)

                    # save image to disk
                    if type(metric[k]) is not str:
                        imwrite(save_path, metric[k])

                    # replace the image in the metric with the file path
                    metric[k] = save_path

    def __load(self):
        # load .experiment file
        with open(self.__get_log_name(), 'r') as file:
            data = json.load(file)
            self.name = data['name']
            self.version = data['version']
            self.autosave = data['autosave']
            self.created_at = data['created_at']
            self.description = data['description']
            self.exp_hash = data['exp_hash']

        # load .tags file
        meta_tags_path = self.get_data_path(self.name,
                                            self.version) + '/meta_tags.csv'
        df = pd.read_csv(meta_tags_path)
        self.tags_list = df.to_dict(orient='records')
        self.tags = {}
        for d in self.tags_list:
            k, v = d['key'], d['value']
            self.tags[k] = v

        # load metrics
        metrics_file_path = self.get_data_path(self.name,
                                               self.version) + '/metrics.csv'
        try:
            df = pd.read_csv(metrics_file_path)
            self.metrics = df.to_dict(orient='records')

            # remove nans
            for metric in self.metrics:
                to_delete = []
                for k, v in metric.items():
                    try:
                        if np.isnan(v):
                            to_delete.append(k)
                    except Exception as e:
                        pass

                for k in to_delete:
                    del metric[k]
        except Exception as e:
            # metrics was empty...
            self.metrics = []

    def get_data_path(self, exp_name, exp_version):
        """
        Returns the path to the local package cache
        :param path:
        :return:
        """
        if self.no_save_dir:
            return os.path.join(_ROOT, 'test_tube_data', exp_name,
                                'version_{}'.format(exp_version))
        else:
            return os.path.join(_ROOT, exp_name,
                                'version_{}'.format(exp_version))

    def get_media_path(self, exp_name, exp_version):
        """
        Returns the path to the local package cache
        :param path:
        :return:
        """
        return os.path.join(self.get_data_path(exp_name, exp_version), 'media')

    def get_tensorboardx_path(self, exp_name, exp_version):
        """
        Returns the path to the local package cache
        :param path:
        :return:
        """
        return os.path.join(self.get_data_path(exp_name, exp_version), 'tf')

    def get_tensorboardx_scalars_path(self, exp_name, exp_version):
        """
        Returns the path to the local package cache
        :param path:
        :return:
        """
        tfx_path = self.get_tensorboardx_path(exp_name, exp_version)
        return os.path.join(tfx_path, 'scalars.json')

    # ----------------------------
    # OVERRIDES
    # ----------------------------
    def _get_file_writer(self):
        """
        Returns the default FileWriter instance. Recreates it if closed.

        Ranks above 0 get a new ``TTDummyFileWriter`` instead. When
        ``purge_step`` is set, the new writer additionally receives this
        experiment's ``debug``/``rank`` values and two start events for
        that step.
        """
        if self.rank > 0:
            return TTDummyFileWriter()

        # an existing writer is reused as is
        if self.all_writers is not None and self.file_writer is not None:
            return self.file_writer

        purge_step = self.purge_step
        self.file_writer = FileWriter(self.log_dir, self.max_queue,
                                      self.flush_secs, self.filename_suffix)

        if purge_step is not None:
            self.file_writer.debug = self.debug
            self.file_writer.rank = self.rank

            version_event = Event(step=purge_step,
                                  file_version='brain.Event:2')
            self.file_writer.add_event(version_event)

            start_event = Event(
                step=purge_step,
                session_log=SessionLog(status=SessionLog.START))
            self.file_writer.add_event(start_event)

        self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
        return self.file_writer

    def __str__(self):
        return 'Exp: {}, v: {}'.format(self.name, self.version)

    def __hash__(self):
        return 'Exp: {}, v: {}'.format(self.name, self.version)

    def flush(self):
        if self.rank > 0:
            return

        if self.all_writers is None:
            return  # ignore double close

        for writer in self.all_writers.values():
            writer.flush()