Example #1
0
 def grid_search(self, estimator, param_grid, features, targets, *,
                 n_splits=3, test_size=0.2, random_state=42, n_jobs=6,
                 verbose=10):
     """Run an exhaustive grid search over ``param_grid`` for ``estimator``.

     Cross-validation uses ``StratifiedShuffleSplit``, so every fold keeps the
     class proportions of ``targets``. The defaults reproduce the previously
     hard-coded settings.

     Args:
         estimator: scikit-learn compatible estimator to tune.
         param_grid: Parameter grid passed to ``GridSearchCV``.
         features: Training features.
         targets: Training labels (used for stratification and fitting).
         n_splits: Number of shuffle-split folds.
         test_size: Fraction of samples held out in each fold.
         random_state: Seed for the splitter, for reproducible folds.
         n_jobs: Number of parallel jobs used by ``GridSearchCV``.
         verbose: ``GridSearchCV`` verbosity level.

     Returns:
         The fitted ``GridSearchCV`` object.
     """
     print("\nGrid search for algorithm:  {}".format(estimator))
     cv = StratifiedShuffleSplit(n_splits=n_splits,
                                 test_size=test_size,
                                 random_state=random_state)
     grid = GridSearchCV(estimator=estimator,
                         param_grid=param_grid,
                         cv=cv,
                         verbose=verbose,
                         n_jobs=n_jobs)
     grid.fit(features, targets)
     print("The best parameters are %s with a score of %0.2f" %
           (grid.best_params_, grid.best_score_))
     return grid
Example #2
0
def _apply_autoaug(transform_train, autoaug):
    """Insert the augmentation policy named ``autoaug`` at the front of ``transform_train``.

    Does nothing when ``autoaug`` is empty/None. The names ``'default'``,
    ``'inception'`` and ``'inception320'`` are accepted but add no policy.

    Raises:
        ValueError: if ``autoaug`` is not one of the recognised names.
    """
    if not autoaug:
        return
    print('augmentation: %s' % autoaug)
    if autoaug == 'fa_reduced_cifar10':
        policy = fa_reduced_cifar10()
    elif autoaug == 'fa_reduced_imagenet':
        policy = fa_reduced_imagenet()
    elif autoaug == 'autoaug_cifar10':
        policy = autoaug_paper_cifar10()
    elif autoaug == 'autoaug_extend':
        policy = autoaug_policy()
    elif autoaug in ('default', 'inception', 'inception320'):
        return
    else:
        # Report the name that was actually requested.
        raise ValueError('not found augmentations. %s' % autoaug)
    # Index 0: the policy runs first, before the crop/flip/ToTensor steps.
    transform_train.transforms.insert(0, Augmentation(policy))


def _cv_split(dataset, cv, n_splits=5):
    """Return ``(train_subset, valid_subset)`` for cross-validation fold ``cv``.

    Uses a seeded ``StratifiedShuffleSplit`` (``n_splits`` folds, 20% held
    out) over ``dataset.targets`` and selects fold number ``cv`` (0-based).
    When ``cv < 0`` the full dataset is returned for training together with
    an empty validation subset.

    Raises:
        ValueError: if ``cv >= n_splits``.
    """
    if cv < 0:
        return dataset, Subset(dataset, [])
    if cv >= n_splits:
        raise ValueError('cv fold %d out of range for %d splits' %
                         (cv, n_splits))
    splitter = StratifiedShuffleSplit(n_splits=n_splits,
                                      test_size=0.2,
                                      random_state=0)
    folds = splitter.split(list(range(len(dataset))), dataset.targets)
    # Advance the (deterministic, random_state=0) fold generator to fold `cv`.
    for _ in range(cv + 1):
        train_idx, valid_idx = next(folds)
    return Subset(dataset, train_idx), Subset(dataset, valid_idx)


def _make_loader(dataset, batch_size, workers, shuffle):
    """Wrap ``dataset`` in a pinned-memory ``DataLoader``."""
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       num_workers=workers,
                                       pin_memory=True)


def main():
    """Train a CutMix model on CIFAR-10/100 or ImageNet from command-line options.

    Parses ``parser`` into the module-global ``args`` and builds the datasets
    and loaders. When ``args.cv >= 0`` it holds out a stratified fold of the
    training set. It then creates the network, the CutMix loss and a Nesterov
    SGD optimizer, and trains for ``args.epochs`` epochs. After every epoch it
    evaluates on the held-out fold and on the validation split, updates the
    module globals ``best_err1`` / ``best_err5``, and writes a checkpoint.

    Raises:
        Exception: for an unknown ``args.dataset``.
        ValueError: for an unknown ``args.autoaug`` or ``args.net_type``.
    """
    global args, best_err1, best_err5
    args = parser.parse_args()

    if args.dataset.startswith('cifar'):
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        _apply_autoaug(transform_train, args.autoaug)

        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        # dataset name -> (torchvision dataset class, number of classes)
        cifar_variants = {
            'cifar100': (datasets.CIFAR100, 100),
            'cifar10': (datasets.CIFAR10, 10),
        }
        if args.dataset not in cifar_variants:
            raise Exception('unknown dataset: {}'.format(args.dataset))
        dataset_cls, numberofclass = cifar_variants[args.dataset]

        ds_train = dataset_cls(args.cifarpath,
                               train=True,
                               download=True,
                               transform=transform_train)
        ds_train, ds_valid = _cv_split(ds_train, args.cv)
        ds_test = dataset_cls(args.cifarpath,
                              train=False,
                              transform=transform_test)

        train_loader = _make_loader(
            CutMix(ds_train,
                   numberofclass,
                   beta=args.cutmix_beta,
                   prob=args.cutmix_prob,
                   num_mix=args.cutmix_num),
            args.batch_size, args.workers, shuffle=True)
        tval_loader = _make_loader(ds_valid, args.batch_size, args.workers,
                                   shuffle=False)
        # Evaluation order does not affect the metrics; keep it fixed for
        # both CIFAR variants, as for ImageNet.
        val_loader = _make_loader(ds_test, args.batch_size, args.workers,
                                  shuffle=False)

    elif args.dataset == 'imagenet':
        traindir = os.path.join(args.imagenetpath, 'train')
        valdir = os.path.join(args.imagenetpath, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        jittering = utils.ColorJitter(brightness=0.4,
                                      contrast=0.4,
                                      saturation=0.4)
        lighting = utils.Lighting(alphastd=0.1,
                                  eigval=[0.2175, 0.0188, 0.0045],
                                  eigvec=[[-0.5675, 0.7192, 0.4009],
                                          [-0.5808, -0.0045, -0.8140],
                                          [-0.5836, -0.6948, 0.4203]])

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            jittering,
            lighting,
            normalize,
        ])
        _apply_autoaug(transform_train, args.autoaug)

        train_dataset = datasets.ImageFolder(traindir, transform_train)
        train_dataset, valid_dataset = _cv_split(train_dataset, args.cv)

        numberofclass = 1000
        train_dataset = CutMix(train_dataset,
                               numberofclass,
                               beta=args.cutmix_beta,
                               prob=args.cutmix_prob,
                               num_mix=args.cutmix_num)

        train_loader = _make_loader(train_dataset, args.batch_size,
                                    args.workers, shuffle=True)
        tval_loader = _make_loader(valid_dataset, args.batch_size,
                                   args.workers, shuffle=False)
        val_loader = _make_loader(
            datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ])),
            args.batch_size, args.workers, shuffle=False)
    else:
        raise Exception('unknown dataset: {}'.format(args.dataset))

    print("=> creating model '{}'".format(args.net_type))
    if args.net_type == 'resnet':
        model = RN.ResNet(args.dataset, args.depth, numberofclass, True)
    elif args.net_type == 'pyramidnet':
        model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha,
                                numberofclass, True)
    elif 'wresnet' in args.net_type:
        model = WRN(args.depth,
                    args.alpha,
                    dropout_rate=0.0,
                    num_classes=numberofclass)
    else:
        raise ValueError('unknown network architecture: {}'.format(
            args.net_type))

    model = torch.nn.DataParallel(model).cuda()
    print('the number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # define loss function (criterion) and optimizer
    criterion = CutMixCrossEntropyLoss(True)
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=1e-4,
                                nesterov=True)
    cudnn.benchmark = True

    for epoch in range(0, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        model.train()
        err1, err5, train_loss = run_epoch(train_loader, model, criterion,
                                           optimizer, epoch, 'train')
        train_err1 = err1
        # Pass over the held-out training fold (no optimizer). Its return
        # values are not used here; err1/err5 are recomputed on the
        # validation set below.
        # NOTE(review): the model is still in train() mode for this pass;
        # confirm that this is intended.
        run_epoch(tval_loader, model, criterion, None, epoch, 'train-val')

        # evaluate on validation set
        model.eval()
        err1, err5, val_loss = run_epoch(val_loader, model, criterion, None,
                                         epoch, 'valid')

        # remember best prec@1 and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best:
            best_err5 = err5
            print('Current Best (top-1 and 5 error):', best_err1, best_err5)

        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.net_type,
                'state_dict': model.state_dict(),
                'best_err1': best_err1,
                'best_err5': best_err5,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            filename='checkpoint_e%d_top1_%.3f_%.3f.pth' %
            (epoch, train_err1, err1))

    print('Best(top-1 and 5 error):', best_err1, best_err5)
    # NOTE(review): the lines below look pasted from a different
    # (scikit-learn grid-search) example. `grid_search_tune` and `test_y`
    # are not defined anywhere in the visible source, so this raises
    # NameError unless they exist as module globals.
    print("Best parameters set:")
    print(grid_search_tune.best_estimator_.steps)
    print()

    # measuring performance on test set
    print("Applying best classifier on test data:")
    best_clf = grid_search_tune.best_estimator_
    predictions = best_clf.predict(test_x)

    print(classification_report(test_y, predictions, target_names=genres))


# Label columns: every column except the free-text 'title' and 'plot' fields.
genres = list(data_df.drop(['title', 'plot'], axis=1).columns.values)
# DataFrame.as_matrix() was removed in pandas 1.0; to_numpy() returns the
# equivalent ndarray.
data_x = data_df[['plot']].to_numpy()
data_y = data_df.drop(['title', 'plot'], axis=1).to_numpy()
# NOTE(review): not used in the visible code; the split below is done with
# train_test_split instead.
stratified_split = StratifiedShuffleSplit(n_splits=2, test_size=0.33)

x_train, x_test, y_train, y_test = train_test_split(data_x,
                                                    data_y,
                                                    test_size=0.33,
                                                    random_state=42)

# Flatten the (n, 1) matrices of plots into lists of strings for a
# TfidfVectorizer.
train_x = [x[0].strip() for x in x_train.tolist()]
test_x = [x[0].strip() for x in x_test.tolist()]

stop_words = set(stopwords.words('english'))

# TODO: learn FeatureUnion to add more features (time, region); see
# http://michelleful.github.io/code-blog/2015/06/20/pipelines/
    if args.corrupt != '':
        corrupt_type, corrupt_level = args.corrupt.split(':')
        corrupt_level = int(corrupt_level)
        print(f'corruption {corrupt_type} : {corrupt_level}')

        from imagenet_c import corrupt
        if not corrupt_type.isdigit():
            ts.insert(corrupt_idx, lambda img: PIL.Image.fromarray(corrupt(np.array(img), corrupt_level, corrupt_type)))
        else:
            ts.insert(corrupt_idx, lambda img: PIL.Image.fromarray(corrupt(np.array(img), corrupt_level, None, int(corrupt_type))))

    transform_test = transforms.Compose(ts)

    testset = ImageNet(root='/data/public/rw/datasets/imagenet-pytorch', split='val', transform=transform_test)
    sss = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
    for _ in range(1):
        sss = sss.split(list(range(len(testset))), testset.targets)
    train_idx, valid_idx = next(sss)
    testset = Subset(testset, valid_idx)

    testloader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch, shuffle=False, num_workers=32, pin_memory=True, drop_last=False)

    metric = Accumulator()
    dl_test = tqdm(testloader)
    data_id = 0
    tta_rule_cnt = [0] * tta_num
    for data, label in dl_test:
        data = data.view(-1, data.shape[-3], data.shape[-2], data.shape[-1])
        data = data.cuda()
Example #5
0
def main():
    """Train, or resume training, a searched cell architecture on CIFAR-10.

    The genotype is built from ``args.arch``. Training uses CutMix, SGD with
    a cosine-annealed learning rate, and an optional model EMA. After every
    epoch, the model with the best validation accuracy seen so far is saved
    to ``<args.save>/top1.pt`` and the latest state to
    ``<args.save>/model.pt``. When the module-global ``continue_train`` is
    true, training resumes from ``<args.save>/model.pt`` with the epoch after
    the one stored in that checkpoint.
    """
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)
    cur_epoch = 0
    ckpt_path = os.path.join(args.save, 'model.pt')

    # SECURITY: eval() executes args.arch as Python source. Only pass trusted
    # values (e.g. the name of an architecture defined in this module).
    net = eval(args.arch)
    print(net)
    # The node count is derived as len(net) / 4.
    node_num = int(len(net) / 4)
    code = gen_code_from_list(net, node_num=node_num)
    genotype = translator([code, code], max_node=node_num)
    print(genotype)

    if continue_train:
        print('continue train from checkpoint')
    else:
        print('train from the scratch')

    model = Network(args.init_ch, 10, args.layers, args.auxiliary,
                    genotype).cuda()
    if not continue_train:
        print("model init params values:", flatten_params(model))
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = CutMixCrossEntropyLoss(True).cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.wd)

    if continue_train:
        checkpoint = torch.load(ckpt_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        # The checkpoint is written at the end of an epoch, so this is the
        # last *completed* epoch.
        cur_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    model_ema = None
    if args.model_ema:
        ema_kwargs = {'resume': ckpt_path} if continue_train else {}
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            **ema_kwargs)

    train_transform, valid_transform = utils._auto_data_transforms_cifar10(
        args)

    ds_train = dset.CIFAR10(root=args.data,
                            train=True,
                            download=True,
                            transform=train_transform)

    # NOTE(review): args.cv is forced to -1, so the stratified split below
    # never runs; ds_valid is also not used later in this function.
    args.cv = -1
    if args.cv >= 0:
        sss = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
        sss = sss.split(list(range(len(ds_train))), ds_train.targets)
        for _ in range(args.cv + 1):
            train_idx, valid_idx = next(sss)
        ds_valid = Subset(ds_train, valid_idx)
        ds_train = Subset(ds_train, train_idx)
    else:
        ds_valid = Subset(ds_train, [])

    train_queue = torch.utils.data.DataLoader(CutMix(ds_train,
                                                     10,
                                                     beta=1.0,
                                                     prob=0.5,
                                                     num_mix=2),
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=2,
                                              pin_memory=True)

    # Evaluation order does not affect accuracy, so the test set is not shuffled.
    valid_queue = torch.utils.data.DataLoader(dset.CIFAR10(
        root=args.data, train=False, transform=valid_transform),
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=2,
                                              pin_memory=True)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    # NOTE(review): best_acc is not restored on resume, so top1.pt can be
    # overwritten by a model worse than the earlier best.
    best_acc = 0.0

    start_epoch = 0
    if continue_train:
        # The loop steps the scheduler once at the start of every epoch, so
        # epoch k runs after k + 1 steps. Replaying cur_epoch + 1 steps here
        # makes the loop's first step land on the schedule for epoch
        # cur_epoch + 1.
        for _ in range(cur_epoch + 1):
            scheduler.step()
        # Resume with the next epoch instead of re-training the completed one.
        start_epoch = cur_epoch + 1

    for epoch in range(start_epoch, args.epochs):
        print('cur_epoch is', epoch)
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        drop_path_prob = args.drop_path_prob * epoch / args.epochs
        model.drop_path_prob = drop_path_prob
        if model_ema is not None:
            model_ema.ema.drop_path_prob = drop_path_prob

        train_acc, _ = train(train_queue, model, criterion, optimizer, epoch,
                             model_ema)
        logging.info('train_acc: %f', train_acc)

        if model_ema is not None and not args.model_ema_force_cpu:
            valid_acc_ema, _ = infer(valid_queue,
                                     model_ema.ema,
                                     criterion,
                                     ema=True)
            logging.info('valid_acc_ema %f', valid_acc_ema)

        valid_acc, _ = infer(valid_queue, model, criterion)
        logging.info('valid_acc: %f', valid_acc)

        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }

        if valid_acc > best_acc:
            best_acc = valid_acc
            print('this model is the best')
            torch.save(state, os.path.join(args.save, 'top1.pt'))
        print('current best acc is', best_acc)
        logging.info('best_acc: %f', best_acc)

        # The rolling checkpoint additionally carries the EMA weights, if any.
        if model_ema is not None:
            state['state_dict_ema'] = get_state_dict(model_ema)
        torch.save(state, ckpt_path)

        print('saved to: %s' % ckpt_path)
Example #6
0
def train_test_split(*arrays,
                     test_size=None,
                     train_size=None,
                     shuffle=True,
                     random_split=None,
                     type='dists'):
    """
    Split ``X, y`` into train and test sets, then augment the train set.

    This is a modification of ``sklearn.model_selection.train_test_split()``.
    When ``shuffle`` is true, the split is stratified on ``y`` via
    ``StratifiedShuffleSplit``. Otherwise the first ``n_train`` samples form
    the train set and the next ``n_test`` samples form the test set. Each
    training sample is then replaced by the samples that ``Augmentor``
    generates from it, and each generated sample gets a copy of the original
    label.

    Assumes X and y can be iterated row by row in matching order (e.g. NumPy
    arrays or lists).

    :param arrays: ``X, y``: the data and its labels. Any further arrays are
        split consistently but are not returned.
    :param test_size: Test-set size, interpreted as in sklearn. If both sizes
        are None, 0.1 is used.
    :param train_size: Train-set size, interpreted as in sklearn.
    :param shuffle: Whether to use a stratified shuffled split.
    :param random_split: ``random_state`` passed to ``StratifiedShuffleSplit``.
    :param type: The type of data that the Augmentor shall perceive.
    :return: ``(X_train_augmented, X_test, y_train_augmented, y_test)``. The
        augmented train parts are NumPy arrays, and the augmented labels are
        floats.
    :raises ValueError: if fewer than two arrays are passed.
    """
    if len(arrays) < 2:
        raise ValueError("At least two arrays (X, y) required as input")

    arrays = indexable(*arrays)
    n_samples = _num_samples(arrays[0])
    n_train, n_test = _validate_shuffle_split(n_samples,
                                              test_size,
                                              train_size,
                                              default_test_size=0.1)
    if not shuffle:
        train = np.arange(n_train)
        test = np.arange(n_train, n_train + n_test)
    else:
        cv = StratifiedShuffleSplit(test_size=n_test,
                                    train_size=n_train,
                                    random_state=random_split)
        # Stratify on the labels (the second array).
        train, test = next(cv.split(X=arrays[0], y=arrays[1]))

    # Layout: [a0_train, a0_test, a1_train, a1_test, ...]
    splits = list(
        chain.from_iterable((_safe_indexing(a, train), _safe_indexing(a, test))
                            for a in arrays))
    X_train, X_test, y_train, y_test = splits[:4]

    X_train_final = []
    y_train_final = []
    for sample, label in zip(X_train, y_train):
        augmented = list(Augmentor(sample, type).get_augmented_data())
        # extend() keeps this linear; repeated list concatenation was quadratic.
        X_train_final.extend(augmented)
        # Repeat the label once per generated sample so X and y stay aligned.
        y_train_final.extend(np.ones(len(augmented)) * label)

    return np.array(X_train_final), X_test, np.array(y_train_final), y_test
Example #7
0
def main(_):
    if (FLAGS.config is None):
        config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                   'model_config.yml')
    else:
        config_file = FLAGS.config
    with open(config_file, 'r') as rf:
        params = yaml.load(rf)

    seed = params.get('seed')
    random_state = np.random.RandomState(seed)
    tf.set_random_seed(seed)

    data_dir = params.get('data_dir')
    model_dir = params.get('model_dir')
    experiment_name = params.get('experiment_name')
    train_data_filename = params.get('train_file')
    test_data_filename = params.get('test_file')

    #load sentences data
    print("loading data...", flush=True)
    train_data_file = os.path.join(data_dir, train_data_filename)
    test_data_file = os.path.join(data_dir, test_data_filename)
    dftrain = load_sents_data_semeval2010(train_data_file)
    dftest = load_sents_data_semeval2010(test_data_file, testset=True)
    dftraintest = pd.concat([dftrain, dftest],
                            ignore_index=True).reset_index(drop=True)
    le = LabelEncoder().fit(dftrain.class_.values)
    params['nclass'] = len(le.classes_)
    params['label_encoder_file'] = experiment_name + '_label_encoder.pkl'
    #oversample class w/ only one example, hack for stratified cv
    dftrain = pd.concat(
        [dftrain, dftrain[dftrain.rel == 'ENTITY-DESTINATION(E2,E1)']],
        ignore_index=True).reset_index(drop=True)

    #build vocab
    print("building vocab...", flush=True)
    vocab_list = build_vocab(dftraintest)
    vocab_size = len(vocab_list)
    vocab_dict = dict(zip(vocab_list, range(vocab_size)))
    vocab_inv_dict = dict(zip(range(vocab_size), vocab_list))
    vocab = Vocab(vocab_list, vocab_size, vocab_dict, vocab_inv_dict)
    params['vocab_file'] = experiment_name + '_vocab.pkl'

    #read embeddings
    print("reading embeddings...", flush=True)
    vocab_vec = read_embeddings(params['embeddings.file'], vocab.words,
                                params['embeddings.init_scale'],
                                params['dtype'], random_state)
    embeddings_mat = np.asarray(vocab_vec.values, dtype=params['dtype'])
    embeddings_mat[0, :] = 0  #make embeddings of PADDING all zeros
    params['embeddings.mat.file'] = experiment_name + '_embeddings.pkl'

    #save params, vocab and embeddings in model directory for testing
    print("saving params, vocab, le and embeddings...", flush=True)
    with open(os.path.join(model_dir, experiment_name + '_params.yml'),
              'w') as wf:
        yaml.dump(params, wf, default_flow_style=False)
    with open(os.path.join(model_dir, params.get('vocab_file')), 'wb') as wf:
        pickle.dump(vocab, wf)
    with open(os.path.join(model_dir, params.get('embeddings.mat.file')),
              'wb') as wf:
        pickle.dump(embeddings_mat, wf)
    with open(os.path.join(model_dir, params.get('label_encoder_file')),
              'wb') as wf:
        pickle.dump(le, wf)

    ##cross-validation
    sss = StratifiedShuffleSplit(n_splits=1,
                                 random_state=random_state,
                                 test_size=params.get('devset_size'))
    for trainidx, devidx in sss.split(dftrain.values, dftrain.rel.values):
        cvtraindf = dftrain.iloc[trainidx, :]
        cvdevdf = dftrain.iloc[devidx, :]
        experiment_name = params.get('experiment_name')

        tstream = build_data_streams(cvtraindf, vocab.dict,
                                     params.get('sent_length'), le)
        dstream = build_data_streams(cvdevdf, vocab.dict,
                                     params.get('sent_length'), le)

        print("Training Data Shape: ", cvtraindf.shape)
        print("Dev Data Shape: ", cvdevdf.shape)
        print("Classes: ", le.classes_)

        def graph_ops():
            #2. build model and define its loss minimization approach(training operation)
            mdl = build_model(params)

            ##defining an optimizer to minimize model's loss
            global_step = tf.Variable(0, name="global_step", trainable=False)
            learning_rate = params.get('learning_rate')
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                   momentum=0.8)
            train_op = optimizer.minimize(mdl.loss, global_step=global_step)

            # Summaries for loss & metrics
            loss_summary = tf.summary.scalar("loss", mdl.loss)
            acc_summary = tf.summary.scalar("accuracy", mdl.accuracy)

            init_op = tf.global_variables_initializer()
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)

            return mdl, global_step, train_op, loss_summary, acc_summary, init_op, \
                saver

        with tf.Session() as sess:
            mdl, global_step, train_op, loss_summary, acc_summary, init_op, \
            saver = graph_ops()
            sess.run(init_op)

            #summaries
            ##train  summaries
            train_summary_dir = os.path.join(model_dir, "summaries",
                                             experiment_name, "train")
            train_summary_op = tf.summary.merge([loss_summary, acc_summary])
            train_summary_writer = tf.summary.FileWriter(train_summary_dir,
                                                         sess.graph,
                                                         flush_secs=3)
            ##dev summaries
            dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
            dev_summary_dir = os.path.join(model_dir, "summaries",
                                           experiment_name, "dev")
            dev_summary_writer = tf.summary.FileWriter(dev_summary_dir,
                                                       sess.graph,
                                                       flush_secs=3)

            # train step
            def train_epoch():
                ntrain = tstream.sent.shape[0]
                bsize = params.get('batch_size')
                start = 0
                end = 0
                for start in range(0, ntrain, bsize):
                    end = start + bsize
                    if end > ntrain:
                        end = ntrain

                    train_feed_dict = {
                        mdl.sent: tstream.sent[start:end, :],
                        mdl.label: tstream.label[start:end],
                        mdl.ent1_dist: tstream.ent1_dist[start:end, :],
                        mdl.ent2_dist: tstream.ent2_dist[start:end, :],
                        mdl.dropout_keep_proba: params.get('dropout'),
                        mdl.batch_size: end - start
                    }
                    sess.run([train_op, global_step, mdl.loss],
                             train_feed_dict)

            def train_eval_step():
                sess.run(mdl.running_vars_initializer)
                train_feed_dict = {
                    mdl.sent: tstream.sent,
                    mdl.label: tstream.label,
                    mdl.ent1_dist: tstream.ent1_dist,
                    mdl.ent2_dist: tstream.ent2_dist,
                    mdl.dropout_keep_proba: 1.0,
                    mdl.batch_size: tstream.sent.shape[0]
                }
                tstep, tloss = sess.run([global_step, mdl.loss],
                                        train_feed_dict)
                sess.run(mdl.accuracy_op, train_feed_dict)
                tsummary = sess.run(train_summary_op, train_feed_dict)
                train_summary_writer.add_summary(tsummary, tstep)
                train_eval_score = sess.run(mdl.accuracy)
                return tstep, tloss, train_eval_score

            def eval_step():
                sess.run(mdl.running_vars_initializer)
                dev_feed_dict = {
                    mdl.sent: dstream.sent,
                    mdl.label: dstream.label,
                    mdl.ent1_dist: dstream.ent1_dist,
                    mdl.ent2_dist: dstream.ent2_dist,
                    mdl.dropout_keep_proba: 1.0,
                    mdl.batch_size: dstream.label.shape[0]
                }

                dstep, dloss, preds = sess.run(
                    [global_step, mdl.loss, mdl.preds], dev_feed_dict)
                sess.run(mdl.accuracy_op, dev_feed_dict)
                dacc_ = sess.run(mdl.accuracy)
                l = dstream.label
                p = preds

                class_int_labels = list(range(len(le.classes_)))
                target_names = le.classes_

                sess.run(mdl.accuracy_op, dev_feed_dict)
                dsummary = sess.run(dev_summary_op, dev_feed_dict)
                dev_summary_writer.add_summary(dsummary, dstep)
                eval_score = (f1_score(l, p, average='micro'),
                              f1_score(l, p, average='macro'), dacc_)
                print(
                    "EVAL step {}, loss {:g}, f1_micro {:g} f1_macro {:g} accuracy {:g}"
                    .format(tstep, dloss, eval_score[0], eval_score[1],
                            eval_score[2]),
                    flush=True)
                official_score = eval_score[1]

                print("Classification Report: \n%s" % classification_report(
                    l,
                    p,
                    labels=class_int_labels,
                    target_names=target_names,
                ),
                      flush=True)

                return official_score

            #training loop
            best_score = 0.0
            best_step = 0
            best_itr = 0
            for ite in range(params.get('training_iters')):
                train_epoch()
                if ite % params.get('train_step_eval') == 0:
                    tstep, tloss, tacc_ = train_eval_step()

                if ite % params.get('train_step_eval') == 0:
                    print(
                        "TRAIN step {}, iteration {} loss {:g} accuracy {:g}".
                        format(tstep, ite, tloss, tacc_),
                        flush=True)

                current_step = tf.train.global_step(sess, global_step)
                if current_step % params.get('eval_interval') == 0:
                    official_score = eval_step()
                    if best_score < official_score:
                        checkpoint_prefix = os.path.join(
                            params.get('model_dir'), "%s-score-%s" %
                            (experiment_name, str(official_score)))
                        saver.save(sess,
                                   checkpoint_prefix,
                                   global_step=current_step)

                        best_score = official_score
                        best_step = current_step
                        best_itr = ite
                    print("Best Score: %2.3f, Best Step: %d (iteration: %d)" %
                          (best_score, best_step, best_itr))