示例#1
0
    def trainIters(self, n_iters, model_file_path=None):
        """Train until ``n_iters`` steps, logging, checkpointing and validating.

        Every ``config.print_interval`` steps the latest batch loss and the
        best validation loss seen so far are logged. Every
        ``config.model_save_iters`` steps a 'train' checkpoint is written and
        evaluated with ``Evaluate``. An extra 'eval' checkpoint is saved
        whenever the validation loss beats the best value from this call.

        Args:
            n_iters: Step count at which training stops (compared with the
                step counter returned by ``setup_train``).
            model_file_path: Optional checkpoint forwarded to ``setup_train``.
        """
        step, running_avg_loss = self.setup_train(model_file_path)
        tic = time.time()
        best_val_loss = np.inf
        while step < n_iters:
            loss = self.train_one_batch(self.batcher.next_batch())
            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, step)
            step += 1

            if step % config.print_interval == 0:
                elapsed = time.time() - tic
                tf.logging.info(
                    'steps %d, seconds for %d batch: %.2f , loss: %f, min_val_loss: %f'
                    % (step, config.print_interval, elapsed, loss,
                       best_val_loss))
                tic = time.time()

            if step % config.model_save_iters != 0:
                continue

            # Periodic checkpoint + validation pass.
            self.summary_writer.flush()
            ckpt_path = self.save_model(running_avg_loss, step, mode='train')
            tf.logging.info('Evaluate the model %s at validation set....' %
                            ckpt_path)
            val_avg_loss = Evaluate(ckpt_path).run_eval()
            if val_avg_loss < best_val_loss:
                best_val_loss = val_avg_loss
                best_path = self.save_model(running_avg_loss, step, mode='eval')
                tf.logging.info('Save best model at %s' % best_path)
示例#2
0
def validate(epoch_idx, abstracts1):
    """Restore the saved model, score it on ``abstracts1`` and write a report.

    The state dict at ``args.save`` is loaded into the module-level ``model``.
    ``Predictor.preeval_batch`` then generates ``num_exams`` candidate sets,
    and each one is scored with ``Evaluate.evaluate``. The per-exam score of
    every metric is pickled to ``figure.pkl``. The scores of the LAST exam are
    written to ``cwd + config.relative_score_path % epoch_idx``.

    Args:
        epoch_idx: Value substituted into ``config.relative_score_path``.
        abstracts1: Dataset to evaluate; must support ``len()`` and expose a
            ``vectorizer`` attribute.

    Returns:
        The mean of the last exam's metric scores.
    """
    model.load_state_dict(torch.load(args.save))
    print("model restored")
    test_loader = DataLoader(abstracts1, config.batch_size)
    eval_f = Evaluate()
    num_exams = 3
    predictor = Predictor(model, abstracts1.vectorizer)
    print("Start Evaluating")
    print("Test Data: ", len(abstracts1))
    cand, ref = predictor.preeval_batch(test_loader, len(abstracts1),
                                        num_exams)
    fields = ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4", "METEOR", "ROUGE_L"]
    # scores[j][i] holds metric fields[j] for exam i.
    scores = [[] for _ in fields]
    for i in range(num_exams):
        print("No.", i)
        final_scores = eval_f.evaluate(live=True, cand=cand[i], ref=ref)
        for metric_scores, field in zip(scores, fields):
            metric_scores.append(final_scores[field])
    with open('figure.pkl', 'wb') as f:
        pickle.dump((fields, scores), f)
    # Only the last exam's scores go into the per-epoch report. The epoch
    # index is formatted into the relative path only, so a '%' in cwd is
    # left alone.
    f_out_name = cwd + config.relative_score_path % epoch_idx
    final = [final_scores[field] for field in fields]
    # `with` closes the report even if a write fails.
    with open(f_out_name, 'w') as f_out:
        for field, value in zip(fields, final):
            f_out.write(field + ':  ')
            f_out.write(str(value) + '\n')
    print("FFFF = ", final_scores)
    return sum(final) / float(len(final))
示例#3
0
def main():
    """Train, predict and evaluate a Lorenz map prediction model.

    Options come from ``ArgParser`` and are shared by every stage. Assets
    are saved in the 'assets' folder in the project directory.

    Supported models include the Conditional Wavenet-inspired (cw) and the
    Unconditional Wavenet-inspired (w) networks.

    The prediction target is the x (ts=0), y (ts=1) or z (ts=2) Lorenz
    trajectory.
    """
    options = ArgParser().parse_args()
    train_data, test_data = LorenzMapData(options).generate_train_test_sets()

    # Fit the model on the training split.
    Train(options).train(
        DIterators(options).build_iterator(train_data, for_train=True))

    # Predict on the test split.
    Predict(options).predict(
        DIterators(options).build_iterator(test_data, for_train=False))

    # Evaluate performance on the test set.
    Evaluate(options)()
示例#4
0
    def run(self, n_iters, model_path=None, max_hours=11.0):
        """Train for a fixed wall-clock budget, validating every 20000 steps.

        The loop stops once ``max_hours`` hours have passed since training
        started. ``n_iters`` is kept for interface compatibility but is not a
        stopping criterion; the step-count condition was replaced by the time
        budget.

        Every 100 steps the summary writer is flushed and the latest losses
        are printed, along with the seconds elapsed since the previous report.
        Every 20000 steps the model is saved as ``model_temp`` and evaluated
        with ``Evaluate(...).run()``. It is saved again as a regular
        checkpoint when the eval loss improves on the best value so far.

        Args:
            n_iters: Unused; retained for backward compatibility.
            model_path: Optional checkpoint forwarded to ``setup_train``.
            max_hours: Wall-clock budget in hours (default 11.0, the value
                that used to be hard-coded).
        """
        step, running_avg_loss = self.setup_train(model_path)
        # `train_start` anchors the time budget. `start` is reset at every
        # report and only measures time between reports. Using `start` for the
        # budget meant the elapsed time never exceeded the limit.
        train_start = time.time()
        start = train_start
        interval = 100
        prev_eval_loss = float("inf")
        while (time.time() - train_start) / 3600 <= max_hours:
            batch = self.batcher.next_batch()
            loss, cove_loss = self.train_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, step)
            step += 1

            if step % interval == 0:
                self.summary_writer.flush()
                print('step: %d, second: %.2f , loss: %f, cover_loss: %f' %
                      (step, time.time() - start, loss, cove_loss))
                start = time.time()
            if step % 20000 == 0:
                self.save_model(running_avg_loss, step, 'model_temp')
                eval_loss = Evaluate(os.path.join(self.model_dir,
                                                  'model_temp')).run()
                if eval_loss < prev_eval_loss:
                    print(
                        f"eval loss for iteration: {step} is {eval_loss}, previous best eval loss = {prev_eval_loss}, saving checkpoint..."
                    )
                    prev_eval_loss = eval_loss
                    self.save_model(running_avg_loss, step)
                else:
                    print(
                        f"eval loss for iteration: {step} is {eval_loss}, previous best eval loss = {prev_eval_loss}, no improvement, skipping..."
                    )
示例#5
0
    def __init__(self, args):
        """Prepare config, checkpoint path, RNG seed, evaluator, data and model.

        Args:
            args: Parsed options. Reads ``dataset``, ``conf`` and
                ``local_rank``; a ``save`` entry holding the checkpoint
                filename is written into it.
        """
        self.args = args

        # Select the model configuration for this dataset / config name.
        dataset_configs = configurations.init(args.dataset)
        self.config = dataset_configs[args.conf]
        if args.local_rank == 0:
            print("Config is", args.conf)

        # Checkpoint filename: models/<experiment_name>.pkl, stored on args.
        checkpoint_name = self.config.experiment_name + '.pkl'
        vars(self.args)['save'] = "models/" + checkpoint_name

        # Seed manually for reproducibility, before data and model creation.
        self.seed()

        # BLEU / METEOR / ROUGE scorer used for validation.
        self.validation_eval = Evaluate()

        # Training and validation datasets.
        datasets = self.load_datasets()
        self.training_abstracts, self.validation_abstracts = datasets

        # Build the model last.
        self.model = self.initialize_model()
示例#6
0
def evaluate(seed, constant, experience_dir, valid_path, batch_size):
    """Compute the FID of a trained generator on a validation set.

    The generator weights are loaded from
    ``<experience_dir>/model/generator.pth``. The process exits with status 1
    if that checkpoint is missing. The FID and its spread are printed and
    written as ``evaluate;<fid>;<fid_var>`` to
    ``<experience_dir>/evaluate_fid.csv``.

    Args:
        seed: Seed for NumPy and Torch (CPU and CUDA) RNGs.
        constant: Forwarded to the ``Generator`` constructor.
        experience_dir: Experiment directory (created if missing).
        valid_path: Validation data location for ``CreateValidLoader``.
        batch_size: Validation batch size.
    """
    import sys

    # makedirs also creates missing parent directories.
    if not os.path.exists(experience_dir):
        os.makedirs(experience_dir)

    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    generator = nn.DataParallel(
        Generator(constant, in_channels=1, out_channels=3)).cuda()
    i2v = nn.DataParallel(Illustration2Vec()).cuda()

    generator_ckpt = os.path.join(experience_dir, 'model/generator.pth')
    if not os.path.isfile(generator_ckpt):
        # `exit` is a site-module helper meant for the interactive shell;
        # sys.exit works in every interpreter configuration.
        print('[ERROR] Need generator checkpoint!', file=sys.stderr)
        sys.exit(1)
    checkG = torch.load(generator_ckpt)
    generator.load_state_dict(checkG['generator'])

    generator.eval()

    img_size = (512, 512)
    valloader = CreateValidLoader(valid_path, batch_size, img_size)
    config = {
        'target_npz': './res/model/fid_stats_color.npz',
        'corps': 2,
        'image_size': img_size[0]
    }
    # Named `evaluator` so it does not shadow this function's own name.
    evaluator = Evaluate(generator, i2v, valloader,
                         namedtuple('Config', config.keys())(*config.values()))

    fid, fid_var = evaluator()
    print(
        f'\n===================\nFID = {fid} +- {fid_var}\n===================\n'
    )

    with open(os.path.join(experience_dir, 'evaluate_fid.csv'), 'w') as f:
        f.write(f'evaluate;{fid};{fid_var}\n')
示例#7
0
             print(outputs[i])
         print('-' * 120)
         count += 1
 elif args.mode == 3:
     cwd = os.getcwd()
     test_data_path = cwd + config.relative_test_path
     test_abstracts = headline2abstractdataset(
         test_data_path,
         training_abstracts.vectorizer,
         args.cuda,
         max_len=1000,
         use_topics=config.use_topics,
         use_structure_info=config.use_labels)
     load_checkpoint()
     test_loader = DataLoader(test_abstracts, config.batch_size)
     eval_f = Evaluate()
     num_exams = 3
     predictor = Predictor(model,
                           training_abstracts.vectorizer,
                           use_cuda=args.cuda)
     print("Start Evaluating")
     print("Test Data: ", len(validation_abstracts))
     cand, ref, org = predictor.preeval_batch(test_loader,
                                              len(validation_abstracts),
                                              num_exams,
                                              use_topics=config.use_topics,
                                              use_labels=config.use_labels)
     scores = []
     fields = ["Bleu_4", "METEOR", "ROUGE_L"]
     for i in range(3):
         scores.append([])
示例#8
0
    def trainIters(self, n_iters, model_file_path=None):
        """Train with a mode-dependent (alpha, beta) schedule, validating periodically.

        ``config.mode`` decides how ``alpha`` and ``beta`` (passed to
        ``train_one_batch``) evolve at each step:

        * 'RL': alpha=0, beta=0
        * 'GTI': alpha=1, beta=0
        * 'SO': alpha=1, beta=k2 / (k2 + exp((step - delay) / k2))
        * 'SIO': alpha is multiplied by k1 each step. While alpha >= 0.01,
          beta=1 and ``delay`` grows by one; afterwards beta follows the
          'SO' formula.
        * 'DAGGER': alpha is multiplied by k1 each step, beta=1
        * 'DAGGER*': alpha=config.alpha, beta=1
        * anything else: alpha=1, beta=1

        Every ``config.save_model_iter`` steps a 'train' checkpoint is saved
        and evaluated. An 'eval' checkpoint is written when the validation
        loss improves, and the validation loss is logged to TensorBoard.

        Args:
            n_iters: Step count at which training stops.
            model_file_path: Optional checkpoint forwarded to ``setup_train``.
        """
        step, running_avg_loss = self.setup_train(model_file_path)
        best_val_loss = np.inf

        k1, k2 = config.k1, config.k2
        alpha, beta = config.alpha, config.beta
        delay = 0

        def decayed_beta(cur_step, cur_delay):
            # k2 / (k2 + exp((step - delay) / k2)), shared by 'SO' and 'SIO'.
            return k2 / (k2 + np.exp((cur_step - cur_delay) / k2))

        while step < n_iters:
            mode = config.mode
            if mode == 'RL':
                alpha, beta = 0., 0.
            elif mode == 'GTI':
                alpha, beta = 1., 0.
            elif mode == 'SO':
                alpha, beta = 1., decayed_beta(step, delay)
            elif mode == 'SIO':
                alpha *= k1
                if alpha < 0.01:
                    beta = decayed_beta(step, delay)
                else:
                    beta, delay = 1., delay + 1
            elif mode in ('DAGGER', 'DAGGER*'):
                alpha = alpha * k1 if mode == 'DAGGER' else config.alpha
                beta = 1.
            else:
                alpha, beta = 1., 1.

            batch = self.batcher.next_batch()
            loss, avg_reward = self.train_one_batch(batch, alpha, beta)
            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, step)
            step += 1

            if step % config.print_interval == 0:
                print('steps %d, current_loss: %f, avg_reward: %f' %
                      (step, loss, avg_reward))

            if step % config.save_model_iter != 0:
                continue

            ckpt_path = self.save_model(running_avg_loss, step, mode='train')
            val_avg_loss = Evaluate(ckpt_path).run_eval()
            if val_avg_loss < best_val_loss:
                best_val_loss = val_avg_loss
                best_path = self.save_model(running_avg_loss, step, mode='eval')
                print('Save best model at %s' % best_path)
            print('steps %d, train_loss: %f, val_loss: %f' %
                  (step, loss, val_avg_loss))
            # Record the validation loss in TensorBoard.
            val_summary = tf.compat.v1.Summary()
            val_summary.value.add(tag='val_avg_loss', simple_value=val_avg_loss)
            self.summary_writer.add_summary(val_summary, global_step=step)
            self.summary_writer.flush()
示例#9
0
from eval import Evaluate, ControlPlots
from train_xgb import StackedModel
from train import ModelCNN

# HDF5 locations of the test / validation / training splits.
data_test = "/mnt/lustre/helios-home/kubumiro/dl/TensorFlow-GPU/Masters/Masters/Data_split/data_test/data_test.h5"
data_valid = "/mnt/lustre/helios-home/kubumiro/dl/TensorFlow-GPU/Masters/Masters/Data_split/data_valid/data_valid.h5"
data_train = "/mnt/lustre/helios-home/kubumiro/dl/TensorFlow-GPU/Masters/Masters/Data_split/data_train/data_train.h5"

# Model names: CNNs "CNN", "ResNet", "ShortCNN"; boosters "XGB",
# "RandomForest", "AdaBoost", "GradientBoosting".
#
# Alternative runs, kept for reference:
#   model = ModelCNN(name="ShortCNN"); model.Train(); model.Predict(data_valid)
#   Evaluate("CVN", data_test, histo=True)
#   model.Train(known_maps=True)  # optional training step before Predict

# Predict with each ResNet-based stacked model on the training split, then
# evaluate it under its run name.
for booster, run_name in (("XGB", "XGB_ResNet"),
                          ("RandomForest", "RandomForest_ResNet"),
                          ("AdaBoost", "AdaBoost_ResNet")):
    model = StackedModel(name_cnn="ResNet", name_boost=booster)
    model.Predict(data_train)
    Evaluate(run_name, data_train, histo=True)
示例#10
0
    def train(self, n_iters, init_model_path=None):
        """Train with scheduled teacher forcing and eval-based early stopping.

        Each step trains on one batch using the forcing ratio returned by
        ``_forcing_ratio``. Every ``config.print_interval`` steps the raw loss
        goes to TensorBoard and the running average is logged. When
        ``config.eval_interval`` is set, the model is evaluated every that
        many steps. It is saved whenever the eval loss reaches a new minimum,
        and training stops after more than ``config.patience`` consecutive
        evaluations without improvement. On non-eval steps the model is saved
        every ``config.save_interval`` steps.

        Args:
            n_iters: Step count at which training stops.
            init_model_path: Optional checkpoint forwarded to ``setup_train``.
        """
        step, avg_loss = self.setup_train(init_model_path)
        start = time.time()
        cnt = 0  # consecutive evaluations without improvement
        best_model_path = None
        min_eval_loss = float('inf')
        while step < n_iters:
            forcing_ratio = self._forcing_ratio(step)
            batch = self.batcher.next_batch()
            loss = self.train_one_batch(batch, forcing_ratio=forcing_ratio)
            model_path = os.path.join(self.checkpoint_dir,
                                      'model_step_%d' % (step + 1))
            avg_loss = calc_avg_loss(loss, avg_loss)

            if (step + 1) % config.print_interval == 0:
                with self.train_summary_writer.as_default():
                    tf.summary.scalar(name='loss', data=loss, step=step)
                self.train_summary_writer.flush()
                logger.info('steps %d, took %.2f seconds, train avg loss: %f' %
                            (step + 1, time.time() - start, avg_loss))
                start = time.time()
            if config.eval_interval is not None and (
                    step + 1) % config.eval_interval == 0:
                start = time.time()
                logger.info("Start Evaluation on model %s" % model_path)
                eval_processor = Evaluate(self.model, self.vocab)
                eval_loss = eval_processor.run_eval()
                logger.info(
                    "Evaluation finished, took %.2f seconds, eval loss: %f" %
                    (time.time() - start, eval_loss))
                with self.eval_summary_writer.as_default():
                    tf.summary.scalar(name='eval_loss',
                                      data=eval_loss,
                                      step=step)
                self.eval_summary_writer.flush()
                if eval_loss < min_eval_loss:
                    logger.info(
                        "This is the best model so far, saving it to disk.")
                    min_eval_loss = eval_loss
                    best_model_path = model_path
                    self.save_model(model_path, eval_loss, step)
                    cnt = 0
                else:
                    cnt += 1
                    if cnt > config.patience:
                        logger.info(
                            "Eval loss doesn't drop for %d straight times, early stopping.\n"
                            "Best model: %s (Eval loss: %f)" %
                            (config.patience, best_model_path, min_eval_loss))
                        break
                start = time.time()
            elif (step + 1) % config.save_interval == 0:
                self.save_model(model_path, avg_loss, step)
            step += 1
        else:
            # Reached only when the loop ends without an early-stopping break.
            # The reported value is the best EVAL loss, not a train loss.
            logger.info(
                "Training finished, best model: %s, with eval loss %f" %
                (best_model_path, min_eval_loss))

    @staticmethod
    def _forcing_ratio(step):
        """Return the teacher-forcing ratio to use at training step ``step``.

        Without ``config.forcing_decay_type`` the constant
        ``config.forcing_ratio`` is returned. Otherwise the ratio starts from
        ``config.forcing_ratio`` and decays ('linear', 'exp' or 'sig') to 0
        at step ``config.decay_to_0_iter``; from that step on it is 0.

        Raises:
            ValueError: For an unrecognized decay type. This is checked only
                while ``step < config.decay_to_0_iter``.
        """
        decay_type = config.forcing_decay_type
        if not decay_type:
            return config.forcing_ratio
        s = config.forcing_ratio
        k = config.decay_to_0_iter
        near_zero = 0.0001  # small constant shaping the exp / sig curves
        if step >= k:
            return 0
        if decay_type == 'linear':
            return s * (k - step) / k
        if decay_type == 'exp':
            # p**k == near_zero, so the ratio reaches s * near_zero at step k.
            p = pow(near_zero, 1 / k)
            return s * (p**step)
        if decay_type == 'sig':
            r = math.log((1 / near_zero) - 1) / k
            return s / (1 + pow(math.e, r * (step - k / 2)))
        # str() keeps the message identical for strings and avoids a
        # TypeError for non-string values.
        raise ValueError('Unrecognized forcing_decay_type: ' + str(decay_type))
示例#11
0
    def trainIters(self, n_iters, model_file_path=None):
        """Train with a mode-dependent (alpha, beta) schedule, logging to a file.

        ``config.mode`` must be one of MLE, RL, GTI, SO, SIO, DAGGER or
        DAGGER*. It decides how ``alpha`` and ``beta`` (passed to
        ``train_one_batch``) evolve at each step. Progress goes through
        ``print_log`` to ``<config.log_root>/log``. That file is closed when
        the method exits, including when training raises. Every
        ``config.save_model_iter`` steps a 'train' checkpoint is saved and
        evaluated; an 'eval' checkpoint is saved when the validation loss
        improves, and the validation loss is written to TensorBoard.

        Args:
            n_iters: Step count at which training stops.
            model_file_path: Optional checkpoint forwarded to ``setup_train``.

        Raises:
            ValueError: If ``config.mode`` is not a supported training mode.
        """
        if config.mode not in ("MLE", "RL", "GTI", "SO", "SIO", "DAGGER",
                               "DAGGER*"):
            print("\nTRAINING MODE ERROR\n")
            raise ValueError("Unsupported training mode: %r" % (config.mode,))
        # log file path
        log_path = os.path.join(config.log_root, 'log')
        # `with` closes the log even if setup or training raises.
        with open(log_path, 'w') as log:
            print_log("==============================", file=log)
            step, running_avg_loss = self.setup_train(
                model_file_path,
                emb_v_path=config.emb_v_path,
                emb_list_path=config.vocab_path,
                vocab=self.vocab,
                log=log)
            min_val_loss = np.inf

            alpha = config.alpha
            beta = config.beta
            k1 = config.k1
            k2 = config.k2
            delay = step  # set to 0 in the original code (wyu-du)

            print("\nLog root is %s" % config.log_root)
            print_log("Train mode is %s" % config.mode, file=log)
            print_log("k1: %s, k2: %s" % (config.k1, config.k2), file=log)
            print_log("==============================", file=log)

            cur_time = time.time()
            while step < n_iters:
                if config.mode == 'RL':
                    alpha = 0.
                    beta = 0.
                elif config.mode == 'GTI':
                    alpha = 1.
                    beta = 0.
                elif config.mode == 'SO':
                    alpha = 1.
                    beta = k2 / (k2 + np.exp((step - delay) / k2))
                elif config.mode == 'SIO':
                    alpha *= k1
                    if alpha < 0.01:
                        beta = k2 / (k2 + np.exp((step - delay) / k2))
                    else:
                        beta = 1.
                        delay += 1
                elif config.mode == 'DAGGER':
                    alpha *= k1
                    beta = 1.
                elif config.mode == 'DAGGER*':
                    alpha = config.alpha
                    beta = 1.
                else:
                    alpha = 1.
                    beta = 1.

                batch = self.batcher.next_batch()
                loss, avg_reward = self.train_one_batch(batch, alpha, beta)
                running_avg_loss = calc_running_avg_loss(
                    loss, running_avg_loss, self.summary_writer, step)
                step += 1

                if step % config.print_interval == 0:
                    print_log(
                        'steps %d, current_loss: %f, avg_reward: %f, alpha: %f, beta: %f, delay: %d'
                        % (step, loss, avg_reward, alpha, beta, delay),
                        file=log)

                if step % config.save_model_iter == 0:
                    ckpt_path = self.save_model(running_avg_loss,
                                                step,
                                                mode='train')
                    evl_model = Evaluate(ckpt_path)
                    val_avg_loss = evl_model.run_eval()
                    if val_avg_loss < min_val_loss:
                        min_val_loss = val_avg_loss
                        best_model_file_path = self.save_model(
                            running_avg_loss, step, mode='eval')
                        print_log('Save best model at %s' %
                                  best_model_file_path,
                                  file=log)
                    print_log(
                        'steps %d, train_loss: %f, val_loss: %f, time: %ds' %
                        (step, loss, val_avg_loss, time.time() - cur_time),
                        file=log)
                    # write val_loss into tensorboard
                    loss_sum = tf.compat.v1.Summary()
                    loss_sum.value.add(tag='val_avg_loss',
                                       simple_value=val_avg_loss)
                    self.summary_writer.add_summary(loss_sum, global_step=step)
                    self.summary_writer.flush()
                    cur_time = time.time()
示例#12
0
     for i in range(num_exams):
         out_name = f_out_name % i
         f_out = open(out_name, 'w')
         for j in range(len(title)):
             f_out.write(title[j] + '\n' + outputs[i][j] + '\n\n')
             if j % 100 == 0:
                 print("Percentages:  %.4f" % (j/float(len(abstracts))))
         f_out.close()
     f_out.close()
 elif args.mode == 3:
     model.load_state_dict(torch.load(args.save))
     print("model restored")
     dev_data_path = cwd + config.relative_dev_path
     abstracts = headline2abstractdataset(dev_data_path, abstracts.vectorizer, args.cuda, max_len=1000)
     test_loader = DataLoader(abstracts, config.batch_size)
     eval_f = Evaluate()
     num_exams = 8
     predictor = Predictor(model, abstracts.vectorizer)
     print("Start Evaluating")
     print("Test Data: ", len(abstracts))
     cand, ref = predictor.preeval_batch(test_loader, len(abstracts), num_exams)
     scores = []
     fields = ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4", "METEOR", "ROUGE_L"]
     for i in range(6):
         scores.append([])
     for i in range(num_exams):
         print("No.", i)
         final_scores = eval_f.evaluate(live=True, cand=cand[i], ref=ref)
         for j in range(6):
             scores[j].append(final_scores[fields[j]])
     with open('figure.pkl', 'wb') as f:
示例#13
0
def train(double, seed, constant, experience_dir, train_path, valid_path,
          batch_size, epochs, drift, adwd, method):
    """Train the sketch-colorization generator against a discriminator.

    ``experience_dir`` and its ``model/`` and ``vizu/`` subdirectories are
    created if missing. Training resumes from ``model/generator.pth`` and
    ``model/discriminator.pth`` when those checkpoints exist. Each inner step
    draws one batch for the discriminator update and one for the generator
    update. When ``double`` is true it draws one more for the sketcher
    update, which maps generated colored images back to sketches.

    Both checkpoints are rewritten at the end of every epoch. A preview
    image is written to ``vizu/`` every 20 steps. Every 20 epochs the FID is
    computed with ``Evaluate`` and appended to ``fid.csv``.

    Args:
        double: Also train the sketcher network.
        seed: Seed for the NumPy and Torch (CPU and CUDA) RNGs.
        constant: Forwarded to the ``Generator`` / ``Discriminator``
            constructors.
        experience_dir: Output directory for checkpoints, previews and logs.
        train_path: Training data location for ``CreateTrainLoader``.
        valid_path: Validation data location for ``CreateValidLoader``.
        batch_size: Batch size for both loaders.
        epochs: Epoch index to stop at (exclusive).
        drift: Weight of the squared real-score term in the discriminator
            loss.
        adwd: Weight of the adversarial term in the generator loss.
        method: Forwarded to ``CreateTrainLoader``.
    """
    for directory in (experience_dir,
                      os.path.join(experience_dir, 'model'),
                      os.path.join(experience_dir, 'vizu')):
        if not os.path.exists(directory):
            os.mkdir(directory)

    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    generator = nn.DataParallel(
        Generator(constant, in_channels=1, out_channels=3)).cuda()
    discriminator = nn.DataParallel(Discriminator(constant)).cuda()
    features = nn.DataParallel(Features()).cuda()
    i2v = nn.DataParallel(Illustration2Vec()).cuda()
    sketcher = nn.DataParallel(
        Generator(constant, in_channels=3,
                  out_channels=1)).cuda() if double else None

    # One optimizer covers the generator and, in double mode, the sketcher.
    g_params = (list(generator.parameters()) + list(sketcher.parameters())
                if double else generator.parameters())
    optimizerG = optim.Adam(g_params, lr=1e-4, betas=(0.5, 0.9))
    optimizerD = optim.Adam(discriminator.parameters(),
                            lr=1e-4,
                            betas=(0.5, 0.9))

    lr_schedulerG = StepLRScheduler(optimizerG, [125000], [0.1], 1e-4, 0, 0)
    lr_schedulerD = StepLRScheduler(optimizerD, [125000], [0.1], 1e-4, 0, 0)

    gen_ckpt_path = os.path.join(experience_dir, 'model/generator.pth')
    disc_ckpt_path = os.path.join(experience_dir, 'model/discriminator.pth')

    # Resume from existing checkpoints. The generator checkpoint is read once,
    # and both checkpoints are released right after use instead of staying
    # alive for the whole run.
    start_epoch = 0
    if os.path.isfile(gen_ckpt_path):
        checkG = torch.load(gen_ckpt_path)
        generator.load_state_dict(checkG['generator'])
        if double:
            sketcher.load_state_dict(checkG['sketcher'])
        optimizerG.load_state_dict(checkG['optimizer'])
        start_epoch = checkG['epoch'] + 1
        del checkG

    if os.path.isfile(disc_ckpt_path):
        checkD = torch.load(disc_ckpt_path)
        discriminator.load_state_dict(checkD['discriminator'])
        optimizerD.load_state_dict(checkD['optimizer'])
        del checkD

    for param in features.parameters():
        param.requires_grad = False
    mse = nn.MSELoss().cuda()
    grad_penalty = GradPenalty(10).cuda()

    img_size = (512, 512)
    mask_gen = Hint((img_size[0] // 4, img_size[1] // 4), 120, (1, 4), 5,
                    (10, 10))
    dataloader = CreateTrainLoader(train_path,
                                   batch_size,
                                   mask_gen,
                                   img_size,
                                   method=method)
    iterator = iter(dataloader)

    valloader = CreateValidLoader(valid_path, batch_size, img_size)
    config = {
        'target_npz': './res/model/fid_stats_color.npz',
        'corps': 2,
        'image_size': img_size[0]
    }
    evaluate = Evaluate(generator, i2v, valloader,
                        namedtuple('Config', config.keys())(*config.values()))

    for epoch in range(start_epoch, epochs):
        batch_id = -1

        pbar = tqdm(total=len(dataloader),
                    desc=f'Epoch [{(epoch + 1):06d}/{epochs}]')
        step = 0
        dgp = 0.
        gc = 0.
        s = 0. if double else None

        while batch_id < len(dataloader):
            lr_schedulerG.step(epoch)
            lr_schedulerD.step(epoch)

            generator.train()
            discriminator.train()
            if double:
                sketcher.train()

            # =============
            # DISCRIMINATOR
            # =============
            for p in discriminator.parameters():
                p.requires_grad = True
            for p in generator.parameters():
                p.requires_grad = False
            optimizerD.zero_grad()
            optimizerG.zero_grad()

            batch_id += 1
            iterator, real_colored, real_sketch, hint = _next_cuda_batch(
                iterator, dataloader)

            with torch.no_grad():
                feat_sketch = i2v(real_sketch).detach()
                fake_colored = generator(real_sketch, hint,
                                         feat_sketch).detach()

            errD_fake = discriminator(fake_colored,
                                      feat_sketch).mean(0).view(1)
            errD_fake.backward(retain_graph=True)

            errD_real = discriminator(real_colored,
                                      feat_sketch).mean(0).view(1)

            errD_realer = -1 * errD_real + errD_real.pow(2) * drift
            errD_realer.backward(retain_graph=True)

            gp = grad_penalty(discriminator, real_colored, fake_colored,
                              feat_sketch)
            gp.backward()

            optimizerD.step()
            pbar.update(1)

            dgp += errD_realer.item() + gp.item()

            # =============
            # GENERATOR
            # =============
            for p in generator.parameters():
                p.requires_grad = True
            for p in discriminator.parameters():
                p.requires_grad = False
            optimizerD.zero_grad()
            optimizerG.zero_grad()

            batch_id += 1
            iterator, real_colored, real_sketch, hint = _next_cuda_batch(
                iterator, dataloader)

            with torch.no_grad():
                feat_sketch = i2v(real_sketch).detach()

            fake_colored = generator(real_sketch, hint, feat_sketch)

            errD = discriminator(fake_colored, feat_sketch)
            errG = -1 * errD.mean() * adwd
            errG.backward(retain_graph=True)

            feat1 = features(fake_colored)
            with torch.no_grad():
                feat2 = features(real_colored)

            contentLoss = mse(feat1, feat2)
            contentLoss.backward()

            optimizerG.step()
            pbar.update(1)

            gc += errG.item() + contentLoss.item()

            # =============
            # SKETCHER
            # =============
            if double:
                for p in generator.parameters():
                    p.requires_grad = True
                for p in discriminator.parameters():
                    p.requires_grad = False
                optimizerD.zero_grad()
                optimizerG.zero_grad()

                batch_id += 1
                iterator, real_colored, real_sketch, hint = _next_cuda_batch(
                    iterator, dataloader)

                with torch.no_grad():
                    feat_sketch = i2v(real_sketch).detach()

                fake_colored = generator(real_sketch, hint, feat_sketch)
                fake_sketch = sketcher(fake_colored, hint, feat_sketch)
                errS = mse(fake_sketch, real_sketch)
                errS.backward()

                optimizerG.step()
                pbar.update(1)

                s += errS.item()

            pbar.set_postfix(
                **{
                    'dgp': dgp / (step + 1),
                    'gc': gc / (step + 1),
                    's': s / (step + 1) if double else None
                })

            # =============
            # PLOT
            # =============
            generator.eval()
            discriminator.eval()

            if step % 20 == 0:
                tensors2vizu(
                    img_size,
                    os.path.join(experience_dir,
                                 f'vizu/out_{epoch}_{step}_{batch_id}.jpg'),
                    **{
                        'strokes': hint[:, :3, ...],
                        'col': real_colored,
                        'fcol': fake_colored,
                        'sketch': real_sketch,
                        'fsketch': fake_sketch if double else None
                    })

            step += 1

        pbar.close()

        torch.save(
            {
                'generator': generator.state_dict(),
                'sketcher': sketcher.state_dict() if double else None,
                'optimizer': optimizerG.state_dict(),
                'double': double,
                'epoch': epoch
            }, gen_ckpt_path)

        torch.save(
            {
                'discriminator': discriminator.state_dict(),
                'optimizer': optimizerD.state_dict(),
                'epoch': epoch
            }, disc_ckpt_path)

        if epoch % 20 == 0:
            fid, fid_var = evaluate()
            print(
                f'\n===================\nFID = {fid} +- {fid_var}\n===================\n'
            )

            with open(os.path.join(experience_dir, 'fid.csv'), 'a+') as f:
                f.write(f'{epoch};{fid};{fid_var}\n')


def _next_cuda_batch(iterator, dataloader):
    """Fetch the next ``(colored, sketch, hint)`` batch and move it to the GPU.

    When ``iterator`` is exhausted, a fresh iterator over ``dataloader`` is
    started. The builtin ``next()`` is used because Python 3 iterators only
    guarantee ``__next__``, not a ``.next()`` method.

    Returns:
        ``(iterator, colored, sketch, hint)``: the possibly restarted
        iterator, followed by the three batch tensors after ``.cuda()``.
    """
    try:
        colored, sketch, hint = next(iterator)
    except StopIteration:
        iterator = iter(dataloader)
        colored, sketch, hint = next(iterator)
    return iterator, colored.cuda(), sketch.cuda(), hint.cuda()