Example #1
File: run.py  Project: qiuchili/monosenti
def run(params):
    model = None
    if 'load_model_from_dir' in params.__dict__ and params.load_model_from_dir:
        print('Loading the model from an existing dir!')
        model_params = pickle.load(
            open(os.path.join(params.dir_name, 'config.pkl'), 'rb'))
        if 'lookup_table' in params.__dict__:
            model_params.lookup_table = params.lookup_table
        if 'sentiment_dic' in params.__dict__:
            model_params.sentiment_dic = params.sentiment_dic
        model = models.setup(model_params)
        model.load_state_dict(
            torch.load(os.path.join(params.dir_name, 'model')))
        model = model.to(params.device)
    else:
        model = models.setup(params).to(params.device)

    if not ('fine_tune' in params.__dict__ and params.fine_tune == False):
        print('Training the model!')
        train(params, model)
        model = torch.load(params.best_model_file)
        os.remove(params.best_model_file)

    performance_dict = test(model, params)
    performance_str = print_performance(performance_dict, params)
    save_model(model, params, performance_str)

    return performance_dict
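The loader above expects two files in params.dir_name: a pickled config.pkl and a 'model' file holding the state dict, but the save_model helper it calls is not shown. Below is a minimal sketch of a matching saver, assuming that two-file layout; the helper name and the performance.txt file are illustrative, not the project's actual API.

import os
import pickle

import torch


def save_model_sketch(model, params, performance_str):
    """Hypothetical counterpart of the loading code above."""
    os.makedirs(params.dir_name, exist_ok=True)
    # Persist the configuration so it can be re-read with pickle.load().
    with open(os.path.join(params.dir_name, 'config.pkl'), 'wb') as f:
        pickle.dump(params, f)
    # Save only the weights; the loader rebuilds the module and calls load_state_dict().
    torch.save(model.state_dict(), os.path.join(params.dir_name, 'model'))
    # Keep a human-readable record of the reported performance.
    with open(os.path.join(params.dir_name, 'performance.txt'), 'w') as f:
        f.write(performance_str)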
Example #2
def bilstm_train_and_eval(train_loader, dev_loader, eval_loader,
                          test_loader, token2id, tag2id, method):
    """训练并保存模型"""

    vocab_size = len(token2id)
    out_size = len(tag2id)
    meta = get_meta([TrainingConfig.__dict__, LSTMConfig.__dict__])

    model = BILSTM_Model(vocab_size, out_size, token2id, tag2id, method=method)
    model.train(train_loader, dev_loader, eval_loader)

    try:
        # Save the model information
        root_dir = "/home/luopx/share_folders/Sohu"
        model_dir = 'ckpts/{}/{}-{}-Len{}-{:.2f}-{:.4f}'.format(
            model.method,
            meta['token_method'],
            meta['tag_schema'],
            meta['max_len'],
            model.best_val_loss,
            model.best_f1_score
        )
        model_dir = join(root_dir, model_dir)

        if not os.path.isdir(model_dir):
            os.mkdir(model_dir)
        save_model(model, join(model_dir, "model.pkl"))

        # Save word2id, tag2id and the model configuration
        with open(join(model_dir, 'meta.json'), 'w') as w:
            w.write(json.dumps(meta, indent=4))

        # Inspect the model's behaviour on the validation set
        print("Evaluating the {} model...".format(method))
        # Analyse the results
        print("Analysing results on the validation set...")
        metrics = model.cal_scores(eval_loader, use_model='best_f1')
        with open(join(model_dir, 'dev_result.txt'), 'w') as outfile:
            metrics.report_details(outfile=outfile)

        # Load the test set, decode, and save the results to a file
        print("Decoding with the model that has the smallest val_loss...")
        test_result = join(model_dir, 'min_devLoss_result.txt')
        decoding(model, test_loader, test_result)
        print("在f1分数值最大的模型上解码...")
        test_result = join(model_dir, 'max_f1_result.txt')
        decoding(model, test_loader, test_result, use_model="best_f1")

    except:
        import pdb
        pdb.set_trace()
Example #3
def run(args, i_cv):
    logger = logging.getLogger()
    print_line()
    logger.info('Running iter n°{}'.format(i_cv))
    print_line()

    result_row = {'i_cv': i_cv}
    result_table = []

    # LOAD/GENERATE DATA
    logger.info('Set up data generator')
    pb_config = S3D2Config()
    seed = config.SEED + i_cv * 5
    train_generator = S3D2(seed)
    valid_generator = S3D2(seed + 1)
    test_generator = S3D2(seed + 2)

    # SET MODEL
    logger.info('Set up regressor')
    args.net = AR5R5E(n_in=3, n_out=2, n_extra=2)
    args.optimizer = get_optimizer(args)
    model = get_model(args, Regressor)
    model.set_info(BENCHMARK_NAME, i_cv)
    model.param_generator = param_generator
    flush(logger)

    # TRAINING / LOADING
    if not args.retrain:
        try:
            logger.info('loading from {}'.format(model.model_path))
            model.load(model.model_path)
        except Exception as e:
            logger.warning(e)
            args.retrain = True
    if args.retrain:
        logger.info('Training {}'.format(model.get_name()))
        model.fit(train_generator)
        logger.info('Training DONE')

        # SAVE MODEL
        save_model(model)

    # CHECK TRAINING
    logger.info('Plot losses')
    plot_REG_losses(model)
    plot_REG_log_mse(model)
    result_row['loss'] = model.losses[-1]
    result_row['mse_loss'] = model.mse_losses[-1]

    # MEASUREMENT
    for mu in pb_config.TRUE_MU_RANGE:
        pb_config.TRUE_MU = mu
        logger.info('Generate testing data')
        test_generator.reset()
        X_test, y_test, w_test = test_generator.generate(
            # pb_config.TRUE_R,
            # pb_config.TRUE_LAMBDA,
            pb_config.CALIBRATED_R,
            pb_config.CALIBRATED_LAMBDA,
            pb_config.TRUE_MU,
            n_samples=pb_config.N_TESTING_SAMPLES)

        p_test = np.array(
            (pb_config.CALIBRATED_R, pb_config.CALIBRATED_LAMBDA))

        pred, sigma = model.predict(X_test, w_test, p_test)
        name = pb_config.INTEREST_PARAM_NAME
        result_row[name] = pred
        result_row[name + _ERROR] = sigma
        result_row[name + _TRUTH] = pb_config.TRUE_MU
        logger.info('{} =vs= {} +/- {}'.format(pb_config.TRUE_MU, pred, sigma))
        result_table.append(result_row.copy())
    result_table = pd.DataFrame(result_table)

    logger.info('Plot params')
    name = pb_config.INTEREST_PARAM_NAME
    plot_params(name,
                result_table,
                title=model.full_name,
                directory=model.results_path)

    logger.info('DONE')
    return result_table
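Note that the measurement loop above mutates a single result_row dict and appends result_row.copy() each iteration; without the copy, every appended row would reference the same object and end up identical. A tiny standalone illustration of that pattern (the values are made up):

import pandas as pd

result_row = {'i_cv': 0}
result_table = []
for mu in (0.1, 0.5, 1.0):                  # stand-in for pb_config.TRUE_MU_RANGE
    result_row['mu'] = mu                   # the same dict is mutated every iteration
    result_table.append(result_row.copy())  # copy() keeps each appended row distinct
print(pd.DataFrame(result_table))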
Example #4
alpha = [1, 0.1, .01, 10]  # Dirichlet parameter
dirname = 'test_batch_result/'
for k in K:
    for a in alpha:
        print('Number of terms: %d' % V)
        # print('Number of documents: %d' % len(W_tr))
        print('Number of documents: %d' % len(W))

        # Model
        lda = LDA_VB(a)
        lda.set_params(K=k, V=V, log=dirname + 'lda_log' + str(count) + '.txt')

        # Fitting
        # lda.fit(W_tr)
        lda.fit(W)

        # Result
        top_idxs = lda.get_top_words_indexes()
        # perplexity = lda.perplexity(W_test)
        with open(dirname + 'lda_result' + str(count) + '.txt', 'w') as f:
            # s = 'Perplexity: %f' % perplexity
            # f.write(s)
            for i in range(len(top_idxs)):
                s = '\nTopic %d:' % i
                for idx in top_idxs[i]:
                    s += ' %s' % inv_dic[idx]
                f.write(s)

        # Save model
        save_model(lda, dirname + 'model' + str(count) + '.csv')
        count += 1
Example #5
def run(args, i_cv):
    logger = logging.getLogger()
    print_line()
    logger.info('Running iter n°{}'.format(i_cv))
    print_line()

    result_row = {'i_cv': i_cv}
    result_table = []

    # LOAD/GENERATE DATA
    logger.info('Set up data generator')
    pb_config = AP1Config()
    seed = config.SEED + i_cv * 5
    train_generator = Generator(param_generator, AP1(seed))
    valid_generator = AP1(seed + 1)
    test_generator = AP1(seed + 2)

    # SET MODEL
    logger.info('Set up regressor')
    args.net = F3R3(n_in=1, n_out=2)
    args.optimizer = get_optimizer(args)
    model = get_model(args, Regressor)
    model.set_info(BENCHMARK_NAME, i_cv)
    flush(logger)

    # TRAINING / LOADING
    if not args.retrain:
        try:
            logger.info('loading from {}'.format(model.path))
            model.load(model.path)
        except Exception as e:
            logger.warning(e)
            args.retrain = True
    if args.retrain:
        logger.info('Training {}'.format(model.get_name()))
        model.fit(train_generator)
        logger.info('Training DONE')

        # SAVE MODEL
        save_model(model)

    # CHECK TRAINING
    logger.info('Plot losses')
    plot_REG_losses(model)
    plot_REG_log_mse(model)
    result_row['loss'] = model.losses[-1]
    result_row['mse_loss'] = model.mse_losses[-1]

    # MEASUREMENT
    for mu in pb_config.TRUE_APPLE_RATIO_RANGE:
        pb_config.TRUE_APPLE_RATIO = mu
        logger.info('Generate testing data')
        X_test, y_test, w_test = test_generator.generate(
            apple_ratio=pb_config.TRUE_APPLE_RATIO,
            n_samples=pb_config.N_TESTING_SAMPLES)

        pred, sigma = model.predict(X_test, w_test)
        name = pb_config.INTEREST_PARAM_NAME
        result_row[name] = pred
        result_row[name + _ERROR] = sigma
        result_row[name + _TRUTH] = pb_config.TRUE_APPLE_RATIO

        logger.info('{} =vs= {} +/- {}'.format(pb_config.TRUE_APPLE_RATIO,
                                               pred, sigma))
        result_table.append(result_row.copy())
    result_table = pd.DataFrame(result_table)

    logger.info('Plot params')
    param_names = pb_config.PARAM_NAMES
    for name in param_names:
        plot_params(name, result_table, model)

    logger.info('DONE')
    return result_table
Example #6
def run(args, i_cv):
    logger = logging.getLogger()
    print_line()
    logger.info('Running iter n°{}'.format(i_cv))
    print_line()
    
    result_row = {'i_cv': i_cv}
    result_table = []

    # LOAD/GENERATE DATA
    logger.info('Set up data generator')
    pb_config = AP1Config()
    seed = config.SEED + i_cv * 5
    train_generator = AP1(seed)
    valid_generator = AP1(seed+1)
    test_generator  = AP1(seed+2)

    # SET MODEL
    logger.info('Set up classifier')
    model = get_model(args, GradientBoostingModel)
    model.set_info(BENCHMARK_NAME, i_cv)
    flush(logger)

    # TRAINING / LOADING
    if not args.retrain:
        try:
            logger.info('loading from {}'.format(model.path))
            model.load(model.path)
        except Exception as e:
            logger.warning(e)
            args.retrain = True
    if args.retrain:
        logger.info('Generate training data')
        X_train, y_train, w_train = train_generator.generate(
                                        apple_ratio=pb_config.CALIBRATED_APPLE_RATIO,
                                        n_samples=pb_config.N_TRAINING_SAMPLES)
        logger.info('Training {}'.format(model.get_name()))
        model.fit(X_train, y_train, w_train)
        logger.info('Training DONE')

        # SAVE MODEL
        save_model(model)


    # CHECK TRAINING
    logger.info('Generate validation data')
    X_valid, y_valid, w_valid = valid_generator.generate(
                                    apple_ratio=pb_config.CALIBRATED_APPLE_RATIO,
                                    n_samples=pb_config.N_VALIDATION_SAMPLES)

    logger.info('Plot distribution of the score')
    plot_valid_distrib(model, X_valid, y_valid, classes=("pears", "apples"))
    result_row['valid_accuracy'] = model.score(X_valid, y_valid)


    # MEASUREMENT
    n_bins = 10
    compute_summaries = ClassifierSummaryComputer(model, n_bins=n_bins)
    for mu in pb_config.TRUE_APPLE_RATIO_RANGE:
        pb_config.TRUE_APPLE_RATIO = mu
        logger.info('Generate testing data')
        X_test, y_test, w_test = test_generator.generate(
                                        apple_ratio=pb_config.TRUE_APPLE_RATIO,
                                        n_samples=pb_config.N_TESTING_SAMPLES)
        
        logger.info('Set up NLL computer')
        compute_nll = AP1NLL(compute_summaries, valid_generator, X_test, w_test)

        logger.info('Plot summaries')
        extension = '-mu={:1.1f}'.format(pb_config.TRUE_APPLE_RATIO)
        plot_summaries( model, n_bins, extension,
                        X_valid, y_valid, w_valid,
                        X_test, w_test, classes=('pears', 'apples', 'fruits') )

        # NLL PLOTS
        logger.info('Plot NLL around minimum')
        plot_apple_ratio_around_min(compute_nll, 
                                    pb_config.TRUE_APPLE_RATIO,
                                    model,
                                    extension)

        # MINIMIZE NLL
        logger.info('Prepare minuit minimizer')
        minimizer = get_minimizer(compute_nll)
        fmin, params = estimate(minimizer)
        params_truth = [pb_config.TRUE_APPLE_RATIO]

        print_params(params, params_truth)
        register_params(params, params_truth, result_row)
        result_row['is_mingrad_valid'] = minimizer.migrad_ok()
        result_row.update(fmin)
        result_table.append(result_row.copy())
    result_table = pd.DataFrame(result_table)

    logger.info('Plot params')
    param_names = pb_config.PARAM_NAMES
    for name in param_names:
        plot_params(name, result_table, title=model.full_name, directory=model.path)

    logger.info('DONE')
    return result_table
Example #7
def train(cfg):
    Dataset = load_dataset(cfg.dataset)
    train_dataset = Dataset('train', cfg)
    val_dataset = Dataset('val', cfg)
    cfg = Config().update_dataset_info(cfg, train_dataset)
    Config().print(cfg)
    logger = Logger(cfg)

    model = SqueezeDetWithLoss(cfg)
    if cfg.load_model != '':
        if cfg.load_model.endswith('f364aa15.pth') or cfg.load_model.endswith(
                'a815701f.pth'):
            model = load_official_model(model, cfg.load_model)
        else:
            model = load_model(model, cfg.load_model)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=cfg.lr,
                                momentum=cfg.momentum,
                                weight_decay=cfg.weight_decay)
    lr_scheduler = StepLR(optimizer, 60, gamma=0.5)

    trainer = Trainer(model, optimizer, lr_scheduler, cfg)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=cfg.batch_size,
                                               num_workers=cfg.num_workers,
                                               pin_memory=True,
                                               shuffle=True,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=cfg.batch_size,
                                             num_workers=cfg.num_workers,
                                             pin_memory=True)

    metrics = trainer.metrics if cfg.no_eval else trainer.metrics + ['mAP']
    best = 1E9 if cfg.no_eval else 0
    better_than = operator.lt if cfg.no_eval else operator.gt

    for epoch in range(1, cfg.num_epochs + 1):
        train_stats = trainer.train_epoch(epoch, train_loader)
        logger.update(train_stats, phase='train', epoch=epoch)

        save_path = os.path.join(cfg.save_dir, 'model_last.pth')
        save_model(model, save_path, epoch)

        if epoch % cfg.save_intervals == 0:
            save_path = os.path.join(cfg.save_dir,
                                     'model_{}.pth'.format(epoch))
            save_model(model, save_path, epoch)

        if cfg.val_intervals > 0 and epoch % cfg.val_intervals == 0:
            val_stats = trainer.val_epoch(epoch, val_loader)
            logger.update(val_stats, phase='val', epoch=epoch)

            if not cfg.no_eval:
                aps = eval_dataset(val_dataset, save_path, cfg)
                logger.update(aps, phase='val', epoch=epoch)

            value = val_stats['loss'] if cfg.no_eval else aps['mAP']
            if better_than(value, best):
                best = value
                save_path = os.path.join(cfg.save_dir, 'model_best.pth')
                save_model(model, save_path, epoch)

        logger.plot(metrics)
        logger.print_bests(metrics)

    torch.cuda.empty_cache()
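The best/better_than pair in train() flips the comparison direction depending on whether the tracked metric is a validation loss (lower is better, cfg.no_eval set) or mAP (higher is better). A stripped-down sketch of that idiom with made-up per-epoch values:

import operator

no_eval = False  # True: track validation loss; False: track mAP

best = 1e9 if no_eval else 0            # worst possible starting point
better_than = operator.lt if no_eval else operator.gt

for value in (0.31, 0.45, 0.52, 0.49):  # e.g. mAP after each validation epoch
    if better_than(value, best):
        best = value                    # this is where model_best.pth would be saved
print('best:', best)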
Example #8
def train(model, criterion, optimizer, train_loader, val_loader, args):
    best_prec1 = 0
    epoch_no_improve = 0

    for epoch in range(1000):

        statistics = Statistics()
        model.train()
        start_time = time.time()

        for i, (input, target) in enumerate(train_loader):
            loss, (prec1, prec5), y_pred, y_true = execute_batch(
                model, criterion, input, target, args.device)

            statistics.update(loss.detach().cpu().numpy(), prec1, prec5,
                              y_pred, y_true)
            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # if args.net_version == 2:
            #    model.camera_position = model.camera_position.clamp(0, 1)
            del loss
            torch.cuda.empty_cache()

        elapsed_time = time.time() - start_time

        # Evaluate on validation set
        val_statistics = validate(val_loader, model, criterion, args.device)

        log_data(statistics, "train", val_loader.dataset.dataset.classes,
                 epoch)
        log_data(val_statistics, "internal_val",
                 val_loader.dataset.dataset.classes, epoch)

        wandb.log({"Epoch elapsed time": elapsed_time}, step=epoch)
        # print(model.camera_position)
        if epoch % 1 == 0:
            vertices = []
            if args.net_version == 1:
                R = look_at_rotation(model.camera_position, device=args.device)
                T = -torch.bmm(R.transpose(1, 2),
                               model.camera_position[:, :, None])[:, :, 0]
            else:
                t = Transform3d(device=model.device).scale(
                    model.camera_position[3] *
                    model.distance_range).rotate_axis_angle(
                        model.camera_position[0] * model.angle_range,
                        axis="X",
                        degrees=False).rotate_axis_angle(
                            model.camera_position[1] * model.angle_range,
                            axis="Y",
                            degrees=False).rotate_axis_angle(
                                model.camera_position[2] * model.angle_range,
                                axis="Z",
                                degrees=False)

                vertices = t.transform_points(model.vertices)

                R = look_at_rotation(vertices[:model.nviews],
                                     device=model.device)
                T = -torch.bmm(R.transpose(1, 2), vertices[:model.nviews, :,
                                                           None])[:, :, 0]

            cameras = OpenGLPerspectiveCameras(R=R, T=T, device=args.device)
            wandb.log(
                {
                    "Cameras":
                    [wandb.Image(plot_camera_scene(cameras, args.device))]
                },
                step=epoch)
            plt.close()
            images = render_shape(model, R, T, args, vertices)
            wandb.log(
                {
                    "Views": [
                        wandb.Image(
                            image_grid(images,
                                       rows=int(np.ceil(args.nviews / 2)),
                                       cols=2))
                    ]
                },
                step=epoch)
            plt.close()
        #  Save best model and best prediction
        if val_statistics.top1.avg > best_prec1:
            best_prec1 = val_statistics.top1.avg
            save_model("views_net", model, optimizer, args.fname_best)
            epoch_no_improve = 0
        else:
            # Early stopping
            epoch_no_improve += 1
            if epoch_no_improve == 20:
                wandb.run.summary[
                    "best_internal_val_top1_accuracy"] = best_prec1
                wandb.run.summary[
                    "best_internal_val_top1_accuracy_epoch"] = epoch - 20

                return
Example #9
def train(model, criterion, optimizer, train_loader, val_loader, args):
    """
    Train the model on the training data loader and stop when the top-1 precision has not increased for args.patience
    epochs. Early stopping is done on the validation data loader.

    All the data are sent to wandb or logged to a text file. More precisely, for each epoch the validation and training
    performance are sent to wandb, and every args.print_freq batches the batch performance is logged to a file. At the
    end of training the best top-1 validation accuracy is sent to wandb.

    Parameters
    ----------
    model : RotationNet model
    criterion : PyTorch criterion (CrossEntropy for RotationNet)
    optimizer : PyTorch optimizer (e.g. SGD)
    train_loader : Data loader with training data (this must be created with a subset)
    val_loader : Data loader with validation data for early stopping (this must be created with a subset)
    args : Input args from the parser

    Returns
    -------
    Nothing
    """
    # Best prediction
    best_prec1 = 0
    # Using lr_scheduler for learning rate decay
    # scheduler = StepLR(optimizer, step_size=args.learning_rate_decay, gamma=0.1)

    epoch_no_improve = 0

    for epoch in range(args.epochs):
        # Give random permutation to the images
        indices = train_loader.dataset.indices
        inds = random_permute_indices(np.array(indices), args.nview, False)
        train_loader.dataset.indices = np.asarray(inds)
        del indices
        del inds

        statistics = Statistics()

        # switch to train mode
        model.train()
        start_time = time.time()
        for batch_idx, (input_val, target_val) in enumerate(train_loader):
            # loss, (prec1, prec5), y_pred, y_true = execute_batch(model, criterion, input_val,
            #                                                     target_val, args)
            loss, (prec1, prec5), y_pred, y_true = execute_batch_aligned(model, criterion, input_val,
                                                                         target_val, args)

            statistics.update(loss.detach().cpu().numpy(), prec1, prec5, y_pred, y_true)
            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % args.print_freq == 0:
                logger.debug('Batch: [{0}/{1}]\t'
                             'Loss {loss:.4f} \t'
                             'Prec@1 {top1:.3f} \t'
                             'Prec@5 {top5:.3f}'.format(batch_idx, len(train_loader), loss=loss.data, top1=prec1,
                                                        top5=prec5))
            del loss
            torch.cuda.empty_cache()

        elapsed_time = time.time() - start_time

        logger.debug("Evaluating epoch {}".format(epoch))

        # permute indices
        # indices = val_loader.dataset.indices
        # inds = random_permute_indices(np.array(indices), args.nview)
        # val_loader.dataset.indices = np.asarray(inds)
        # del indices
        # del inds
        # Evaluate on validation set
        val_statistics = validate(val_loader, model, criterion, args)

        # statistics.compute(args.num_classes)
        # val_statistics.compute(args.num_classes)

        log_data(statistics, "train", val_loader.dataset.dataset.classes, epoch)
        log_data(val_statistics, "internal_val", val_loader.dataset.dataset.classes, epoch)

        wandb.log({"Epoch elapsed time": elapsed_time}, step=epoch)

        #  Save best model and best prediction
        if val_statistics.top1.avg > best_prec1:
            best_prec1 = val_statistics.top1.avg
            save_model(args.arch, model, optimizer, args.fname_best)
            epoch_no_improve = 0
        else:
            # Early stopping
            epoch_no_improve += 1
            if epoch_no_improve == args.patience:
                wandb.run.summary["best_internal_val_top1_accuracy"] = best_prec1
                wandb.run.summary["best_internal_val_top1_accuracy_epoch"] = epoch - args.patience

                logger.debug("Stopping at epoch {} for early stopping (best was at epoch {})".format(epoch,
                                                                                                     epoch - args.patience))
                return
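Several of these examples share the same early-stopping bookkeeping: remember the best top-1 validation accuracy, call save_model when it improves, and stop after args.patience epochs without improvement. A minimal self-contained sketch of that loop follows; the metric values are placeholders and the save call is left as a comment.

def early_stopping_loop(val_accuracies, patience):
    """Return (best_epoch, best_value), stopping after `patience`
    epochs without improvement, as in the training loops above."""
    best, best_epoch, no_improve = 0.0, -1, 0
    for epoch, acc in enumerate(val_accuracies):
        if acc > best:
            best, best_epoch, no_improve = acc, epoch, 0
            # save_model(...) would be called here
        else:
            no_improve += 1
            if no_improve == patience:
                break
    return best_epoch, best


# The peak is at epoch 2; the loop stops after three non-improving epochs.
print(early_stopping_loop([0.61, 0.70, 0.74, 0.73, 0.72, 0.71, 0.70], patience=3))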
Example #10
def train_k_fold(model, criterion, optimizer, args):
    """
    Perform a k-fold cross-validation on the given model. Use args.fold to set the total number of folds and
    args.ntest_folds to set the number of folds held out for evaluation. Keep in mind that each training step also uses
    an extra validation set for early stopping (the size of this subset is set via args.val_split_size).

    All the data are sent to wandb or logged to a text file.

    Parameters
    ----------
    model : RotationNet model
    criterion : PyTorch criterion (CrossEntropy for RotationNet)
    optimizer : PyTorch optimizer (e.g. SGD)
    args : Input args from the parser

    Returns
    -------
    Nothing
    """
    # Save clean model to reload at each fold
    save_model(args.arch, model, optimizer, args.fname)

    # Get full train dataset
    full_set = generate_dataset(args.data, 'train')

    # Get folds
    folds = k_fold(full_set[0].samples, args.nview, True, args.fold)

    # List of top1 and top5 accuracies to get the average and std of the model performance
    top1 = []
    top5 = []

    # K-fold cross validation
    for i in range(args.fold):
        test_idx = []
        # Use ntest_folds folds for the test set
        for j in range(args.ntest_folds):
            test_idx.extend(folds[i])
            folds = np.delete(folds, i, 0)

        # Use rest of the data for the train set
        train_idx = np.hstack(folds)
        val_idx, train_idx = train_test_split([full_set[0].samples[i] for i in train_idx], train_idx, args.nview,
                                              args.val_split_size, True)

        # Get subsets
        test_set = torch.utils.data.Subset(full_set[0], test_idx)
        val_set = torch.utils.data.Subset(full_set[0], val_idx)
        train_set = torch.utils.data.Subset(full_set[0], train_idx)

        # Generate loaders
        test_loader = generate_loader(test_set, args.batch_size, args.workers)
        val_loader = generate_loader(val_set, args.batch_size, args.workers)
        train_loader = generate_loader(train_set, args.batch_size, args.workers)

        logger.debug("Start train on fold {}/{}".format(i, args.fold))

        # Track model in wandb
        wandb.init(project="RotationNet", name="Fold " + str(i), config=args, reinit=True)

        # The model can be analyzed only once
        # if i == 0:
        #    wandb.watch(model)

        train(model, criterion, optimizer, train_loader, val_loader, args)

        # Load best model before validating
        load_model(model, args.fname_best)

        val_statistics = validate(test_loader, model, criterion, args)
        log_summary(val_statistics, "val", test_loader.dataset.dataset.classes)
        # Record fold accuracies so the summary below has values to average
        top1.append(val_statistics.top1.avg)
        top5.append(val_statistics.top5.avg)

        # Load fresh model for next train
        load_model(model, args.fname)

    logger.info('Val prec@1 {top1:.3f} +- {std1:.3f} \t'
                'Val prec@5 {top5:.3f} +- {std5:.3f} \t'.format(top1=np.mean(top1), std1=np.std(top1),
                                                                top5=np.mean(top5), std5=np.std(top5)))
Example #11
def train(model, criterion, optimizer, train_loader, val_loader, args,
          single_view):
    # Best prediction
    best_prec1 = 0

    epoch_no_improve = 0

    for epoch in range(args.epochs):
        if not single_view:
            # Give random permutation to the images
            indices = train_loader.dataset.indices
            inds = random_permute_indices(np.array(indices), args.nview)
            train_loader.dataset.indices = np.asarray(inds)
            del indices
            del inds

        statistics = Statistics()

        # switch to train mode
        model.train()
        start_time = time.time()
        for batch_idx, (input_val, target_val) in enumerate(train_loader):
            loss, (prec1, prec5), y_pred, y_true = execute_batch(
                model, criterion, input_val, target_val, args, single_view)

            statistics.update(loss.detach().cpu().numpy(), prec1, prec5,
                              y_pred, y_true)
            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            del loss
            torch.cuda.empty_cache()

        elapsed_time = time.time() - start_time

        # Evaluate on validation set
        val_statistics = validate(val_loader, model, criterion, args,
                                  single_view)

        # statistics.compute(args.num_classes)
        # val_statistics.compute(args.num_classes)

        log_data(statistics, "train", val_loader.dataset.dataset.classes,
                 epoch)
        log_data(val_statistics, "internal_val",
                 val_loader.dataset.dataset.classes, epoch)

        wandb.log({"Epoch elapsed time": elapsed_time}, step=epoch)

        #  Save best model and best prediction
        if val_statistics.top1.avg > best_prec1:
            best_prec1 = val_statistics.top1.avg
            save_model(args.arch, model, optimizer, args.fname_best)
            epoch_no_improve = 0
        else:
            # Early stopping
            epoch_no_improve += 1
            if epoch_no_improve == args.patience:
                wandb.run.summary[
                    "best_internal_val_top1_accuracy"] = best_prec1
                wandb.run.summary[
                    "best_internal_val_top1_accuracy_epoch"] = epoch - args.patience

                return