Example #1
File: train.py Project: ReaFly/ACSNet
def valid(model, valid_dataloader, total_batch):

    model.eval()

    # Metrics_logger initialization
    metrics = Metrics(['recall', 'specificity', 'precision', 'F1', 'F2',
                       'ACC_overall', 'IoU_poly', 'IoU_bg', 'IoU_mean'])

    with torch.no_grad():
        bar = tqdm(enumerate(valid_dataloader), total=total_batch)
        for i, data in bar:
            img, gt = data['image'], data['label']

            if opt.use_gpu:
                img = img.cuda()
                gt = gt.cuda()

            output = model(img)
            _recall, _specificity, _precision, _F1, _F2, \
            _ACC_overall, _IoU_poly, _IoU_bg, _IoU_mean = evaluate(output, gt)

            metrics.update(recall=_recall, specificity=_specificity, precision=_precision,
                           F1=_F1, F2=_F2, ACC_overall=_ACC_overall, IoU_poly=_IoU_poly,
                           IoU_bg=_IoU_bg, IoU_mean=_IoU_mean)

    metrics_result = metrics.mean(total_batch)

    return metrics_result
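
Examples #1, #4, and #5 all rely on a Metrics helper that is not shown in these excerpts. A minimal sketch of an accumulator consistent with the update()/mean() calls above; the interface is taken from the calls, but the body is an assumption, not the project's implementation:

class Metrics:
    """Running totals for named metrics, averaged over batches (sketch)."""

    def __init__(self, metric_names):
        self._totals = {name: 0.0 for name in metric_names}

    def update(self, **batch_values):
        # Accumulate this batch's value for every named metric.
        for name, value in batch_values.items():
            self._totals[name] += float(value)

    def mean(self, total_batch):
        # Average each metric over the number of batches processed.
        return {name: total / total_batch for name, total in self._totals.items()}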
Example #2
def precision_per_class(preds, labels, mask):
    import heapq
    from sklearn.metrics import roc_curve, auc
    mask = mask.astype(int)
    labels = labels.astype(int)
    val_indexes = np.where(mask == 1)[0]
    pred_true_labels = {}
    for i in val_indexes:
        pred_probs_i = preds[i]
        true_raw_i = labels[i]

        # Top-k predicted classes, where k is the number of true labels
        # for sample i.
        pred_label_i = heapq.nlargest(np.sum(true_raw_i),
                                      range(len(pred_probs_i)),
                                      pred_probs_i.take)
        true_label_i = np.where(true_raw_i == 1)[0]
        pred_true_labels[i] = (pred_label_i, true_label_i)
    accuracy_per_classes = metrics.evaluate(pred_true_labels)

    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    test_y = labels[val_indexes]
    test_pred = preds[val_indexes]
    fpr["micro"], tpr["micro"], _ = roc_curve(test_y.ravel(),
                                              test_pred.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
    #     print('micro_auc=',roc_auc["micro"])

    return accuracy_per_classes, roc_auc["micro"]
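
The heapq.nlargest call above picks, for each sample, the indices of the k highest predicted probabilities, with k equal to that sample's number of true labels. A small standalone illustration (the values are made up):

import heapq
import numpy as np

pred_probs = np.array([0.10, 0.70, 0.05, 0.60])
true_raw = np.array([0, 1, 0, 1])            # two true labels, so k = 2
k = int(np.sum(true_raw))
top_k = heapq.nlargest(k, range(len(pred_probs)), pred_probs.take)
print(top_k)  # [1, 3] -- indices of the two largest probabilities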
Example #3
def train_single_epoch(config, model, dataloader, criterion, optimizer,
                       scheduler, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()

    reses = AverageMeter()
    accs = AverageMeter()
    senses = AverageMeter()
    specs = AverageMeter()

    model.train()

    end = time.time()
    for i, (images, labels) in enumerate(dataloader):
        optimizer.zero_grad()

        images = images.cuda()
        labels = labels.cuda()
        n_data = images.shape[0]

        logits = model(images)

        labels = labels.squeeze(1)

        if config.LOSS.LABEL_SMOOTHING:
            smoother = LabelSmoother()
            loss = criterion(logits, smoother(labels))
        else:
            loss = criterion(logits, labels)

        losses.update(loss.item(), n_data)

        loss.backward()
        optimizer.step()

        # one_cycle_lr schedules per batch rather than per epoch.
        if config.SCHEDULER.NAME == 'one_cycle_lr':
            scheduler.step()
        preds = logits.argmax(dim=1)
        res, _, _, ttl_acc, ttl_sens, ttl_spec = evaluate(preds, labels)

        reses.update(res, n_data)
        accs.update(ttl_acc, n_data)
        senses.update(ttl_sens, n_data)
        specs.update(ttl_spec, n_data)

        batch_time.update(time.time() - end)
        end = time.time()

        if i % config.PRINT_EVERY == 0:
            print(
                '[%d/%d][%d/%d] time: %.2f, loss: %.4f, res: %.4f, acc: %.4f, sens: %.4f, spec: %.4f  [lr: %.6f]'
                % (epoch, config.TRAIN.NUM_EPOCHS, i, len(dataloader),
                   batch_time.sum, loss.item(), res, ttl_acc, ttl_sens,
                   ttl_spec, optimizer.param_groups[0]['lr']), )

        del images, labels, logits
        torch.cuda.empty_cache()

    return (losses.avg, reses.avg, accs.avg, senses.avg, specs.avg)
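
Examples #3 and #9 use an AverageMeter that the excerpts do not define. A minimal sketch matching the attributes used above (val, sum, avg, and update(value, n)); this is the common PyTorch training utility, assumed here rather than copied from the project:

class AverageMeter:
    """Tracks the latest value plus a running sum and average (sketch)."""

    def __init__(self):
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of all values
        self.count = 0    # total weight seen so far
        self.avg = 0.0    # running weighted average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count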
Example #4
def test(exp_name):
    print('loading data......')
    test_data = getattr(datasets, opt.dataset)(opt.root,
                                               opt.test_data_dir,
                                               mode='test',
                                               size=opt.testsize)
    test_dataloader = DataLoader(test_data,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=opt.num_workers)
    total_batch = len(test_data)  # batch_size is 1, so one batch per sample
    model, _, _ = generate_model(opt)

    model.eval()

    # metrics_logger initialization
    metrics = Metrics([
        'recall', 'specificity', 'precision', 'F1', 'F2', 'ACC_overall',
        'IoU_poly', 'IoU_bg', 'IoU_mean'
    ])

    logger = get_logger('./results/' + exp_name + '.log')
    with torch.no_grad():
        for i, data in enumerate(test_dataloader):
            img, gt = data['image'], data['label']

            if opt.use_gpu:
                img = img.cuda()
                gt = gt.cuda()

            output = model(img)
            _recall, _specificity, _precision, _F1, _F2, \
            _ACC_overall, _IoU_poly, _IoU_bg, _IoU_mean = evaluate(output, gt)

            metrics.update(recall=_recall,
                           specificity=_specificity,
                           precision=_precision,
                           F1=_F1,
                           F2=_F2,
                           ACC_overall=_ACC_overall,
                           IoU_poly=_IoU_poly,
                           IoU_bg=_IoU_bg,
                           IoU_mean=_IoU_mean)

    metrics_result = metrics.mean(total_batch)

    print("Test Result:")
    logger.info(
        'recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f, F2: %.4f, '
        'ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f' %
        (metrics_result['recall'], metrics_result['specificity'],
         metrics_result['precision'], metrics_result['F1'],
         metrics_result['F2'], metrics_result['ACC_overall'],
         metrics_result['IoU_poly'], metrics_result['IoU_bg'],
         metrics_result['IoU_mean']))
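
The get_logger helper used above is not part of the excerpt. A minimal sketch built on the standard logging module, assuming it simply appends to the given file (the original may also log to the console or use a different format):

import logging

def get_logger(log_path):
    # Sketch: one file-backed logger per path; guard against duplicate handlers.
    logger = logging.getLogger(log_path)
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        handler = logging.FileHandler(log_path)
        handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
        logger.addHandler(handler)
    return logger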
Example #5
def eval():


    args = DefaultConfig()
    print('#' * 20, 'Start Evaluation', '#' * 20)
    for dataset in tqdm.tqdm(args.testdataset, total=len(args.testdataset), position=0,
                             bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}'):
        pred_path = r'E:\dataset\data\TestDataset/{}/output/'.format(dataset)
        gt_path = r'E:\dataset\data\TestDataset/{}/masks/'.format(dataset)
        # Sort both listings so predictions and masks pair up by filename;
        # os.listdir returns entries in arbitrary order.
        preds = sorted(os.listdir(pred_path))
        gts = sorted(os.listdir(gt_path))
        total_batch = len(preds)
        # metrics_logger initialization
        metrics = Metrics(['recall', 'specificity', 'precision', 'F1', 'F2',
                           'ACC_overall', 'IoU_poly', 'IoU_bg', 'IoU_mean', 'Dice'])

        for i, sample in tqdm.tqdm(enumerate(zip(preds, gts)), desc=dataset + ' - Evaluation', total=len(preds),
                                   position=1, leave=False, bar_format='{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}'):
            pred, gt = sample
            assert os.path.splitext(pred)[0] == os.path.splitext(gt)[0]

            pred_mask = np.array(Image.open(os.path.join(pred_path, pred)))
            gt_mask = np.array(Image.open(os.path.join(gt_path, gt)))
            if len(pred_mask.shape) != 2:
                pred_mask = pred_mask[:, :, 0]
            if len(gt_mask.shape) != 2:
                gt_mask = gt_mask[:, :, 0]

            assert pred_mask.shape == gt_mask.shape
            gt_mask = gt_mask.astype(np.float64) / 255
            pred_mask = pred_mask.astype(np.float64) / 255

            gt_mask = torch.from_numpy(gt_mask)
            pred_mask = torch.from_numpy(pred_mask)
            _recall, _specificity, _precision, _F1, _F2, \
            _ACC_overall, _IoU_poly, _IoU_bg, _IoU_mean, _Dice = evaluate(pred_mask, gt_mask, 0.5)

            metrics.update(recall=_recall, specificity=_specificity, precision=_precision,
                           F1=_F1, F2=_F2, ACC_overall=_ACC_overall, IoU_poly=_IoU_poly,
                           IoU_bg=_IoU_bg, IoU_mean=_IoU_mean, Dice=_Dice
                           )
        metrics_result = metrics.mean(total_batch)
        print("Test Result:")
        print('recall: %.4f, specificity: %.4f, precision: %.4f, F1: %.4f, F2: %.4f, '
              'ACC_overall: %.4f, IoU_poly: %.4f, IoU_bg: %.4f, IoU_mean: %.4f, Dice:%.4f'
              % (metrics_result['recall'], metrics_result['specificity'], metrics_result['precision'],
                 metrics_result['F1'], metrics_result['F2'], metrics_result['ACC_overall'],
                 metrics_result['IoU_poly'], metrics_result['IoU_bg'], metrics_result['IoU_mean'],
                 metrics_result['Dice']))
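
Example #5 thresholds soft masks at 0.5 before scoring. For reference, the IoU and Dice values it reports can be computed for a binary mask pair as below; this is a generic sketch of the standard definitions, not the project's evaluate():

import torch

def iou_and_dice(pred_mask, gt_mask, threshold=0.5):
    # Binarize the soft prediction, then compare overlap with the ground truth.
    pred = (pred_mask > threshold).float()
    gt = (gt_mask > threshold).float()
    inter = (pred * gt).sum()
    total = pred.sum() + gt.sum()
    union = total - inter
    iou = inter / union if union > 0 else torch.tensor(1.0)
    dice = 2 * inter / total if total > 0 else torch.tensor(1.0)
    return iou.item(), dice.item()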
Example #6
    best_criterion = -1
    lr = args.lr
    train_loader, val_loader = get_train_dataloader(args)
    test_loader = get_test_dataloader(args)

    for epoch in range(1, args.epochs + 1):

        y_l = np.array([])
        y_p = np.array([])
        for param_group in optimizer.param_groups:
            current_lr = param_group['lr']
        time_cost, y_l, y_p = train(epoch, y_l, y_p)
        print(
            "====================Epoch:{}==================== Learning Rate:{:.5f}"
            .format(epoch, current_lr))
        SROCC, KROCC, PLCC, RMSE, Acc = evaluate(y_l, y_p)
        writer.add_scalar('Train/SROCC', SROCC, epoch)
        print(
            "Training Results - Epoch: {}  Avg accuracy: {:.3f} RMSE: {:.5f}  SROCC: {:.5f} KROCC: {:.5f} PLCC: {:.5f} ***** Time Cost: {:.1f} s"
            .format(epoch, Acc, RMSE, SROCC, KROCC, PLCC, time_cost))

        y_l = np.array([])
        y_p = np.array([])
        start = time.time()
        y_l, y_p = test(y_l, y_p)
        SROCC, KROCC, PLCC, RMSE, Acc = evaluate(y_l, y_p)
        end = time.time()
        writer.add_scalar('Test/LOSS', RMSE, epoch)
        writer.add_scalar('Test/SROCC', SROCC, epoch)
        print(
            "Testing Results - Epoch: {}  Avg accuracy: {:.3f} RMSE: {:.5f}  SROCC: {:.5f} KROCC: {:.5f} PLCC: {:.5f} ***** Time Cost: {:.1f} s"
            .format(epoch, Acc, RMSE, SROCC, KROCC, PLCC, end - start))
Example #7
def run_experiments(train_file_name,
                    test_file_name,
                    result_file_name,
                    forecast_horizon,
                    past_history_ls,
                    batch_size_ls,
                    epochs_ls,
                    tcn_params=TCN_PARAMS,
                    lstm_params=LSTM_PARAMS,
                    gpu_number=None,
                    metrics_ls=METRICS,
                    buffer_size=1000,
                    seed=1,
                    show_plots=False,
                    webhook=None,
                    validation_size=0.2):
    tf.random.set_seed(seed)
    np.random.seed(seed)

    gpus = tf.config.experimental.list_physical_devices('GPU')
    print(gpus)
    device_name = str(gpus)
    if len(gpus) >= 2 and gpu_number is not None:
        device = gpus[gpu_number]
        tf.config.experimental.set_memory_growth(device, True)
        tf.config.experimental.set_visible_devices(device, 'GPU')
        device_name = str(device)
        print(device)

    # Write result csv header
    current_index = 0
    try:
        with open(result_file_name, 'r') as resfile:
            current_index = sum(1 for line in resfile) - 1
    except IOError:
        pass
    print('CURRENT INDEX', current_index)
    if current_index == 0:
        with open(result_file_name, 'w') as resfile:
            resfile.write(';'.join([
                str(a) for a in [
                    'MODEL', 'MODEL_DESCRIPTION', 'FORECAST_HORIZON',
                    'PAST_HISTORY', 'BATCH_SIZE', 'EPOCHS'
                ] + metrics_ls + ['val_' + m for m in metrics_ls] +
                ['loss', 'val_loss', 'Execution_time', 'Device']
            ]) + "\n")

    # Read train file
    with open(train_file_name, 'r') as datafile:
        ts_train = datafile.readlines()[1:]  # skip the header
        ts_train = np.asarray([
            np.asarray(l.rstrip().split(',')[0], dtype=np.float32)
            for l in ts_train
        ])
        ts_train = np.reshape(ts_train, (ts_train.shape[0], ))

    # Read test data file
    with open(test_file_name, 'r') as datafile:
        ts_test = datafile.readlines()[1:]  # skip the header
        ts_test = np.asarray([
            np.asarray(l.rstrip().split(',')[0], dtype=np.float32)
            for l in ts_test
        ])
        ts_test = np.reshape(ts_test, (ts_test.shape[0], ))

    # Train/validation split
    TRAIN_SPLIT = int(ts_train.shape[0] * (1 - validation_size))
    print(ts_train.shape, TRAIN_SPLIT)
    # Normalize training data
    norm_params = normalization.get_normalization_params(
        ts_train[:TRAIN_SPLIT])
    ts_train = normalization.normalize(ts_train, norm_params)
    # Normalize test data with train params
    ts_test = normalization.normalize(ts_test, norm_params)

    i = 0
    index_1, total_1 = 0, len(
        list(itertools.product(past_history_ls, batch_size_ls, epochs_ls)))
    for past_history, batch_size, epochs in tqdm(
            list(itertools.product(past_history_ls, batch_size_ls,
                                   epochs_ls))):
        index_1 += 1
        # Get x and y for training and validation
        x_train, y_train = data_generation.univariate_data(
            ts_train, 0, TRAIN_SPLIT, past_history, forecast_horizon)
        x_val, y_val = data_generation.univariate_data(
            ts_train, TRAIN_SPLIT - past_history, ts_train.shape[0],
            past_history, forecast_horizon)
        print(x_train.shape, y_train.shape, '\n', x_val.shape, y_val.shape)
        # Get x and y for test data
        x_test, y_test = data_generation.univariate_data(
            ts_test, 0, ts_test.shape[0], past_history, forecast_horizon)

        # Convert numpy data to tensorflow dataset
        train_data = tf.data.Dataset.from_tensor_slices(
            (x_train,
             y_train)).cache().shuffle(buffer_size).batch(batch_size).repeat()
        val_data = tf.data.Dataset.from_tensor_slices((
            x_val,
            y_val)).batch(batch_size).repeat() if validation_size > 0 else None
        test_data = tf.data.Dataset.from_tensor_slices(
            (x_test, y_test)).batch(batch_size)

        # Create models
        model_list = {}
        model_description_list = {}
        if tcn_params is not None:
            model_list = {
                'TCN_{}'.format(j):
                (tcn,
                 [x_train.shape, forecast_horizon, 'adam', 'mae', *params])
                for j, params in enumerate(
                    itertools.product(*tcn_params.values()))
                if params[1] * params[2] * params[3][-1] == past_history
            }
            model_description_list = {
                'TCN_{}'.format(j): str(dict(zip(tcn_params.keys(), params)))
                for j, params in enumerate(
                    itertools.product(*tcn_params.values()))
                if params[1] * params[2] * params[3][-1] == past_history
            }
        if lstm_params is not None:
            model_list = {
                **model_list,
                **{
                    'LSTM_{}'.format(j): (lstm, [
                        x_train.shape, forecast_horizon, 'adam', 'mae', *params
                    ])
                    for j, params in enumerate(
                        itertools.product(*lstm_params.values()))
                }
            }
            model_description_list = {
                **model_description_list,
                **{
                    'LSTM_{}'.format(j): str(
                        dict(zip(lstm_params.keys(), params)))
                    for j, params in enumerate(
                        itertools.product(*lstm_params.values()))
                }
            }

        steps_per_epoch = int(np.ceil(x_train.shape[0] / batch_size))
        validation_steps = steps_per_epoch if val_data else None

        index_2, total_2 = 0, len(model_list.keys())
        for model_name, (model_function, params) in tqdm(model_list.items(),
                                                         position=1):
            index_2 += 1
            i += 1
            if i <= current_index:
                continue
            start = time.time()
            model = model_function(*params)
            print(model.summary())

            # Train the model
            history = model.fit(train_data,
                                epochs=epochs,
                                steps_per_epoch=steps_per_epoch,
                                validation_data=val_data,
                                validation_steps=validation_steps)

            # Plot training and evaluation loss evolution
            if show_plots:
                auxiliary_plots.plot_training_history(history, ['loss'])

            # Get validation results
            val_metrics = {}
            if validation_size > 0:
                val_forecast = model.predict(x_val)
                val_forecast = normalization.denormalize(
                    val_forecast, norm_params)
                y_val_denormalized = normalization.denormalize(
                    y_val, norm_params)

                val_metrics = metrics.evaluate(y_val_denormalized,
                                               val_forecast, metrics_ls)
                print('Validation metrics', val_metrics)

            # TEST
            # Predict with test data and get results
            test_forecast = model.predict(test_data)

            test_forecast = normalization.denormalize(test_forecast,
                                                      norm_params)
            y_test_denormalized = normalization.denormalize(
                y_test, norm_params)
            x_test_denormalized = normalization.denormalize(
                x_test, norm_params)

            test_metrics = metrics.evaluate(y_test_denormalized, test_forecast,
                                            metrics_ls)
            print('Test scores', test_metrics)

            # Plot some test predictions
            if show_plots:
                auxiliary_plots.plot_ts_forecasts(x_test_denormalized,
                                                  y_test_denormalized,
                                                  test_forecast)

            # Save results
            val_metrics = {'val_' + k: val_metrics[k] for k in val_metrics}
            model_metric = {
                'MODEL': model_name,
                'MODEL_DESCRIPTION': model_description_list[model_name],
                'FORECAST_HORIZON': forecast_horizon,
                'PAST_HISTORY': past_history,
                'BATCH_SIZE': batch_size,
                'EPOCHS': epochs,
                **test_metrics,
                **val_metrics,
                **history.history, 'Execution_time': time.time() - start,
                'Device': device_name
            }

            notify_slack('Progress: {0}/{1} ({2}/{3}) \nMetrics:{4}'.format(
                index_1, total_1, index_2, total_2,
                str({
                    'Model':
                    model_name,
                    'WAPE':
                    str(test_metrics['wape']),
                    'Execution_time':
                    "{0:.2f}  seconds".format(time.time() - start)
                })),
                         webhook=webhook)

            with open(result_file_name, 'a') as resfile:
                resfile.write(';'.join([str(a)
                                        for a in model_metric.values()]) +
                              "\n")
Example #8
    optim.step()

    print("[Epoch %3d] Loss : %.4f" % (epoch, loss))



# Evaluation

# Choose evaluation metrics for implicit (top-k)
do_prec, do_recall, do_ndcg = True, True, True
test_nums = [1, 5, 10]
test_matrix = build_matrix(test[0], test[1], test[2], num_users, num_items)

# Prediction
logit = model(train_matrix)

# Mask to ignore train data
# Same shape with full matrix
# 1 if training data, 0 otherwise
mask = train_matrix > 0

if cuda:
    logit = logit.cpu().detach()

# Evaluate prediction
# Explicit => MSE
# Implicit => precision, recall, ndcg
result = evaluate(logit, test_matrix, mask, metrics=(do_prec, do_recall, do_ndcg), nums=test_nums, implicit=implicit)

# Print result
print_result(result, test_nums, implicit)
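
The implicit-feedback metrics requested above (precision, recall, NDCG at k) all start from a top-k ranking with training items masked out. A simplified single-user sketch of precision@k in that spirit (an illustration, not the project's evaluate):

import numpy as np

def precision_at_k(scores, relevant, train_mask, k):
    # Push training items to the bottom, rank the rest, and count how many
    # of the top-k recommendations are relevant in the test set.
    scores = np.where(train_mask, -np.inf, scores)
    top_k = np.argsort(-scores)[:k]
    return float(np.sum(relevant[top_k])) / k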
Example #9
def evaluate_single_epoch(config, model, dataloader, criterion, epoch):
    batch_time = AverageMeter()

    model.eval()

    all_logits = []
    all_labels = []

    with torch.no_grad():
        end = time.time()
        for i, (images, labels) in enumerate(dataloader):
            images = images.cuda()
            labels = labels.cuda()

            logits = model(images)

            labels = labels.squeeze(1)

            all_logits.append(logits)
            all_labels.append(labels)

            batch_time.update(time.time() - end)
            end = time.time()

            if i % config.PRINT_EVERY == 0:
                print('[%2d/%2d] time: %.2f' %
                      (i, len(dataloader), batch_time.sum))

            del images, labels, logits
            torch.cuda.empty_cache()

        all_logits = torch.cat(all_logits, dim=0)
        all_labels = torch.cat(all_labels, dim=0)
        loss = criterion(all_logits, all_labels)

        all_preds = all_logits.argmax(dim=1)
        all_preds = all_preds.detach().cpu().numpy()
        all_labels = all_labels.detach().cpu().numpy()
        res, class_sens, class_spec, ttl_acc, ttl_sens, ttl_spec = evaluate(
            all_preds, all_labels)
        print(
            ' | %12s | %.4f |\n' % ('loss', loss.item()),
            '| %12s | %.10f |\n' % ('res', res),
            # '| %12s | %.4f %.4f %.4f %.4f |\n' % ('class_acc', class_acc[0], class_acc[1], class_acc[2], class_acc[3]),
            '| %12s | %.4f %.4f %.4f %.4f |\n' %
            ('class_sens', class_sens[0], class_sens[1], class_sens[2],
             class_sens[3]),
            '| %12s | %.4f %.4f %.4f %.4f |\n' %
            ('class_spec', class_spec[0], class_spec[1], class_spec[2],
             class_spec[3]),
            '| %12s | %.4f |\n' % ('acc', ttl_acc),
            '| %12s | %.4f |\n' % ('sens', ttl_sens),
            '| %12s | %.4f |\n' % ('spec', ttl_spec),
        )

        nb_classes = 4
        conf_matrix = np.zeros((nb_classes, nb_classes))
        for t, p in zip(all_labels, all_preds):
            conf_matrix[t, p] += 1
        # Rows are the ground truth, columns are the predictions.
        print('Confusion Matrix')
        print(conf_matrix)
        print()

    return loss.item(), res, ttl_acc, ttl_sens, ttl_spec
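
The per-class sensitivity and specificity printed in Example #9 can be derived directly from the confusion matrix it builds. A generic sketch of that derivation (the project's evaluate() may compute them differently):

import numpy as np

def sens_spec_from_confusion(conf_matrix):
    # Rows are ground truth, columns are predictions.
    tp = np.diag(conf_matrix)
    fn = conf_matrix.sum(axis=1) - tp   # true samples missed per class
    fp = conf_matrix.sum(axis=0) - tp   # samples wrongly assigned per class
    tn = conf_matrix.sum() - tp - fn - fp
    return tp / (tp + fn), tn / (tn + fp)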
Example #10
def run_training():
    graph = tf.Graph()
    with graph.as_default():
        # Input data.
        y = tf.placeholder(tf.float32, shape=[None, y_dim], name='condition')
        x = tf.placeholder(tf.float32, shape=[None, x_dim], name='target')
        is_training = tf.placeholder(tf.bool, shape=[])

        y1, y2, y3 = tf.split(y, 3, 0)
        x1, x2, x3 = tf.split(x, 3, 0)

        with tf.variable_scope('Gen') as scope:
            G_sample = generator(y1, reuse=None, is_training=is_training)
            scope.reuse_variables()
            G_representation = generator(y,
                                         reuse=True,
                                         is_training=is_training)
        with tf.variable_scope('Disc') as scope:
            D_real = discriminator(x1, y1, reuse=None)
            scope.reuse_variables()
            D_fake = discriminator(G_sample, y1, reuse=True)

            D_real2 = discriminator(x1, y2, reuse=True)
            D_fake2 = discriminator(G_sample, y2, reuse=True)

            D_real3 = discriminator(x2, y1, reuse=True)
            D_fake3 = discriminator(G_sample, y1, reuse=True)

            D_consx = discriminator(x3, y1, reuse=True)
            D_consy = discriminator(x1, y3, reuse=True)

        # loss of discriminator
        loss_d = 1*(-tf.reduce_mean(D_real) + tf.reduce_mean(D_fake)) + \
                 opt.beta*(1*tf.reduce_mean(D_consx) + 1*tf.reduce_mean(D_consy)) + \
                 1 *(-tf.reduce_mean(D_real2) + tf.reduce_mean(D_fake2)) + \
                 1 * (-tf.reduce_mean(D_real3))
        # loss of generator
        loss_g = -1 * tf.reduce_mean(D_fake) - 1 * tf.reduce_mean(D_fake2)

        # loss of MSE
        loss_eq = tf.nn.l2_loss(G_sample - x1)

        l2_loss = tf.losses.get_regularization_loss()

        tf.summary.histogram("D_real", D_real)
        tf.summary.histogram("D_fake", D_fake)
        tf.summary.histogram("G_sample", G_sample)

        d_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope='Disc')
        g_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope='Gen')
        opt_d = tf.train.RMSPropOptimizer(learning_rate=1e-4).minimize(
            loss_d, var_list=d_params)
        opt_g = tf.train.RMSPropOptimizer(learning_rate=1e-4).minimize(
            loss_g + opt.gamma * loss_eq + l2_loss, var_list=g_params)
        # Weight clipping (WGAN): keep every critic parameter in [-0.01, 0.01].
        clip_d = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in d_params]

        # Add variable initializer.
        init = tf.global_variables_initializer()

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Create a saver for writing training checkpoints.
        saver = tf.train.Saver()

    train_data_set, test_data_set = my_input.get_data(tt='cca')

    # Begin training.
    with tf.Session(graph=graph) as sess:

        # We must initialize all variables before we use them.
        sess.run(init)
        print("Initialized")

        # Instantiate a SummaryWriter to output summaries and the Graph.
        summary_writer = tf.summary.FileWriter(opt.log_dir, sess.graph)

        average_loss_d = 0
        average_loss_g = 0
        start_time = time.time()
        time_sum = 0
        with open(os.path.join(opt.log_dir, 'record.txt'), 'a') as f:
            f.write('-' * 30 + '\n')

        for step in range(opt.max_steps):
            # update discriminator
            for _ in range(opt.critic_itrs):
                batch_image, batch_text, batch_label = train_data_set.next_batch(
                    opt.batch_size)
                if opt.dir == 'txt2img':
                    _, loss_val_d = sess.run([opt_d, loss_d],
                                             feed_dict={
                                                 is_training: True,
                                                 y: batch_text,
                                                 x: batch_image
                                             })
                else:
                    _, loss_val_d = sess.run([opt_d, loss_d],
                                             feed_dict={
                                                 is_training: True,
                                                 y: batch_image,
                                                 x: batch_text
                                             })
                sess.run(clip_d)
                average_loss_d += loss_val_d

            # update generator
            batch_image, batch_text, batch_label = train_data_set.next_batch(
                opt.batch_size)
            if opt.dir == 'txt2img':
                _, loss_val_g = sess.run([opt_g, loss_g],
                                         feed_dict={
                                             is_training: True,
                                             y: batch_text,
                                             x: batch_image
                                         })
            else:
                _, loss_val_g = sess.run([opt_g, loss_g],
                                         feed_dict={
                                             is_training: True,
                                             y: batch_image,
                                             x: batch_text
                                         })
            average_loss_g += loss_val_g

            # Write the summaries and print an overview fairly often.
            if (step + 1) % opt.log_interval == 0:
                average_loss_d /= (opt.log_interval * opt.critic_itrs)
                average_loss_g /= opt.log_interval
                duration = time.time() - start_time
                print(
                    'Step %d: average_loss_d = %.5f, average_loss_g = %.5f (%.3f sec)'
                    % (step, average_loss_d, average_loss_g, duration))
                average_loss_d = 0
                average_loss_g = 0

                if opt.dir == 'txt2img':
                    summary_str = sess.run(summary_op,
                                           feed_dict={
                                               y: batch_text,
                                               x: batch_image
                                           })
                else:
                    summary_str = sess.run(summary_op,
                                           feed_dict={
                                               y: batch_image,
                                               x: batch_text
                                           })
                summary_writer.add_summary(summary_str, step)
                summary_writer.flush()

            # Save a checkpoint and evaluate the model periodically.
            if (step + 1) % opt.save_interval == 0 or (step + 1) == opt.max_steps:
                save_path = saver.save(sess,
                                       os.path.join(opt.save_dir,
                                                    "model"), step)
                print("Model saved in file: %s" % save_path)
                batch_image, batch_text, batch_label = test_data_set.next_batch()
                if opt.dir == 'txt2img':
                    feed_dict = {
                        is_training: False,
                        y: batch_text,
                        x: batch_image
                    }
                    [text_representation,
                     image_representation] = sess.run([G_representation, x],
                                                      feed_dict=feed_dict)
                else:
                    feed_dict = {
                        is_training: False,
                        y: batch_image,
                        x: batch_text
                    }
                    [image_representation,
                     text_representation] = sess.run([G_representation, x],
                                                     feed_dict=feed_dict)
                duration = time.time() - start_time
                time_sum = time_sum + duration
                map_i2t, map_t2i = metrics.evaluate(image_representation,
                                                    text_representation,
                                                    batch_label,
                                                    metric='cos')
                # np.savez("coco_pre_%d.npz" % (step + 1), img_proj_te=image_representation, txt_proj_te=batch_text,
                #          label_te=batch_label)
                start_time = time.time()
                with open(os.path.join(opt.log_dir, 'record.txt'), 'a') as f:
                    f.write('%d %f %f %.3f\n' %
                            (step + 1, map_i2t[-1], map_t2i[-1], time_sum))
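
metrics.evaluate at the end scores cross-modal retrieval: it ranks one modality's representations against the other's and reports MAP in each direction. A simplified sketch of cosine-similarity MAP for single integer labels (an assumption based on the metric='cos' argument; the project's version apparently returns per-cutoff values, hence map_i2t[-1]):

import numpy as np

def retrieval_map(queries, gallery, labels):
    # Rank gallery items by cosine similarity to each query; a retrieved item
    # counts as a hit when it shares the query's label.
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    g = gallery / np.linalg.norm(gallery, axis=1, keepdims=True)
    sims = q @ g.T
    aps = []
    for i, row in enumerate(sims):
        order = np.argsort(-row)
        hits = (labels[order] == labels[i]).astype(float)
        if hits.sum() == 0:
            continue
        precision = np.cumsum(hits) / (np.arange(len(hits)) + 1)
        aps.append(float((precision * hits).sum() / hits.sum()))
    return float(np.mean(aps))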