Example #1
def train(cfg, args):
    # set default device
    device = torch.device(cfg.MODEL.DEVICE)
    # build Butterfly Net as [model]
    model = build_model(cfg, name="Btrfly").to(device)

    # build discriminator nets as [model_D1] and [model_D2] if necessary
    model_D1, model_D2 = None, None
    if cfg.MODEL.USE_GAN:
        model_D1 = build_model(cfg, name="EBGAN").to(device)
        model_D2 = build_model(cfg, name="EBGAN").to(device)
        print(model_D1)

    # to visualize the net, uncomment this block
    """
    input1 = torch.rand(3, 1, 128, 128)  
    input2 = torch.rand(3, 1, 128, 128)
    with SummaryWriter(comment='BtrflyNet') as w:
        w.add_graph(model, (input1, input2, ))
    """

    # learning rate
    lr = cfg.SOLVER.LR
    # optimizer of [model]
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    # optimizers of [model_D1] and [model_D2] if necessary
    optimizer_D1, optimizer_D2 = None, None
    if cfg.MODEL.USE_GAN:
        optimizer_D1 = torch.optim.Adam(model_D1.parameters(), lr=lr)
        optimizer_D2 = torch.optim.Adam(model_D2.parameters(), lr=lr)

    # update [checkpointer] if necessary
    # besides the iteration and epoch numbers, [arguments] also keeps a dict
    # with information about the best models so far, including their
    # numbers and their validation losses
    arguments = {"iteration": 0, "epoch": 0, "list_loss_val": {}}
    checkpointer = CheckPointer(model, optimizer, cfg.OUTPUT_DIR)
    extra_checkpoint_data = checkpointer.load()
    arguments.update(extra_checkpoint_data)

    # build training set from the directory designated by cfg
    dataset = ProjectionDataset(cfg=cfg,
                                mat_dir=cfg.MAT_DIR_TRAIN,
                                input_img_dir=cfg.INPUT_IMG_DIR_TRAIN,
                                transform=transforms.Compose([ToTensor()]),
                                )
    train_loader = DataLoader(dataset, batch_size=cfg.SOLVER.BATCH_SIZE, shuffle=True, num_workers=4)


    return do_train(cfg, args, model, model_D1, model_D2, train_loader, optimizer, optimizer_D1, optimizer_D2, checkpointer, device, arguments)
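For context, a minimal sketch of how this train function might be driven from a CLI; the get_default_cfg helper and the --config flag are assumptions, not part of the original example:

import argparse

def main():
    parser = argparse.ArgumentParser(description='Train the Butterfly Net')
    parser.add_argument('--config', default='configs/btrfly.yaml')  # hypothetical flag
    args = parser.parse_args()

    cfg = get_default_cfg()           # assumed helper returning a yacs-style CfgNode
    cfg.merge_from_file(args.config)  # yacs API: overlay file values on the defaults
    train(cfg, args)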
Example #2
def plot_dual_attention():
    """Plot dual attention vectors of 2 models over given context."""
    modelbw = build_model("imasm",
                          ARGS.model_dir + "curr_imasm64.h5",
                          char_size=len(CHAR_IDX) + 1,
                          dim=ARGS.dim,
                          iterations=ARGS.iterations,
                          training=False)
    modelfw = build_model("fwimarsm",
                          ARGS.model_dir + "curr_fwimarsm64.h5",
                          char_size=len(CHAR_IDX) + 1,
                          dim=ARGS.dim,
                          iterations=ARGS.iterations,
                          training=False)
    ctxs = [
        "p(X):-q(X).q(X):-r(X).r(X):-s(X).s(a).s(b).",
        "p(X):-q(X);r(X).r(a).q(a).r(b).q(b).",
        "p(X):-q(X).p(X):-r(X).p(b).r(a).q(b)."
    ]
    fig, axes = plt.subplots(1, 6)
    # Plot the attention
    for i, ctx in enumerate(ctxs):
        for j, m in enumerate([modelbw, modelfw]):
            rs = ctx.split('.')[:-1]
            dgen = LogicSeq([[([r + '.' for r in rs], "p(a).", 0)]],
                            1,
                            False,
                            False,
                            pad=ARGS.pad)
            out = m.predict_generator(dgen)
            sims = out[:-1]
            out = np.round(out[-1].item(), 2)  # .item() replaces np.asscalar (removed in NumPy 1.23)
            sims = np.stack(sims, axis=0).squeeze()
            sims = sims.T
            ticks = (["()"] if ARGS.pad else []) + ["$\phi$"]
            axes[i * 2 + j].get_xaxis().set_ticks_position('top')
            sns.heatmap(sims,
                        vmin=0,
                        vmax=1,
                        cmap="Blues",
                        yticklabels=rs + ticks if j % 2 == 0 else False,
                        xticklabels=range(1, 5),
                        linewidths=0.5,
                        square=True,
                        cbar=False,
                        ax=axes[i * 2 + j])
            # axes[i*2+j].set_xlabel("p(Q|C)=" + str(out))
            axes[i * 2 + j].set_xlabel("backward" if j % 2 == 0 else "forward")
    plt.tight_layout()
    showsave_plot()
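showsave_plot is not shown here; a plausible helper, assuming a hypothetical ARGS.outf option for writing the figure to disk:

def showsave_plot():
    # hypothetical: save the current figure if an output file was requested,
    # otherwise display it interactively
    if getattr(ARGS, 'outf', None):
        plt.savefig(ARGS.outf, bbox_inches='tight')
    else:
        plt.show()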
Example #3
  def __init__(self, pretrained_model_conf, learnable_model_conf,
               mode='add', input_mode='input', mse_path_model_conf=None,
               freeze_pretrained_model=True, disable_strict_loading=False):
    super(RefinementWrapper, self).__init__()
    self.mode = mode
    self.freeze_pretrained_model = freeze_pretrained_model
    self.pretrained_model = build_model(pretrained_model_conf,
                                        pretrained_model_conf.name)
    self.learnable_model = build_model(learnable_model_conf,
                                       learnable_model_conf.name)

    if mode == 'add':
      self._refine_op = self._refinement_add
    elif mode == 'real-penalty-add':
      self.scale = nn.Parameter(torch.zeros(1))
      self._refine_op = self._refinement_real_penalty_add
    else:
      raise ValueError('Unknown mode {}'.format(mode))

    if input_mode == 'input':
      self._learnable_model_input_fn = lambda inp, out: inp
    elif input_mode == 'output':
      self._learnable_model_input_fn = lambda inp, out: out
    elif input_mode == 'concat':
      self._learnable_model_input_fn = lambda inp, out: torch.cat((inp, out),
                                                                  dim=1)
    else:
      raise ValueError('Unknown input mode {}'.format(input_mode))

    # As models can have different arguments on their forward function,
    # we need to dynamically select a forward function which fits the
    # signature of the pretrained model. For now we assume that the learnable
    # model always has one input, and that only the pretrained model can have
    # different signatures.
    forward_replacements = [
        self._forward_vanilla,
        self._forward_reconstruction
    ]

    signature_pretrained = inspect.signature(self.pretrained_model.forward)
    params_pretrained = signature_pretrained.parameters
    for forward_fn in forward_replacements:
      if params_pretrained == inspect.signature(forward_fn).parameters:
        self.forward = forward_fn
        break
    else:
      raise RuntimeError(('Could not find fitting forward '
                          'function with params {}').format(params_pretrained))
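The refinement ops and forward variants referenced above are outside this snippet; a hedged sketch of what they might look like, inferred only from how the constructor wires them up:

  def _refinement_add(self, pretrained_out, learnable_out):
    # 'add' mode: treat the learnable model as a residual correction
    return pretrained_out + learnable_out

  def _refinement_real_penalty_add(self, pretrained_out, learnable_out):
    # 'real-penalty-add' mode: scale the correction by a learned scalar
    return pretrained_out + self.scale * learnable_out

  def _forward_vanilla(self, x):
    # single-input pretrained model; parameters must match its forward() signature
    out = self.pretrained_model(x)
    inp = self._learnable_model_input_fn(x, out)
    return self._refine_op(out, self.learnable_model(inp))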
Example #4
    def __init__(self, model_path, gpu_id=0):
        from models import build_model
        from data_loader import get_dataloader
        from post_processing import get_post_processing
        from utils import get_metric
        self.gpu_id = gpu_id
        if self.gpu_id is not None and isinstance(
                self.gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % self.gpu_id)
            torch.backends.cudnn.benchmark = True
        else:
            self.device = torch.device("cpu")
        print('load model:', model_path)
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        config = checkpoint['config']
        config['arch']['backbone']['pretrained'] = False

        self.validate_loader = get_dataloader(config['dataset']['validate'],
                                              config['distributed'])

        self.model = build_model(config['arch'].pop('type'), **config['arch'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.to(self.device)

        self.post_process = get_post_processing(config['post_processing'])
        self.metric_cls = get_metric(config['metric'])
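Only the constructor is shown; usage might look like the following (the class name EvalRunner, its eval method, and the path are guesses, not part of the original snippet):

runner = EvalRunner('output/DBNet/checkpoint/model_best.pth', gpu_id=0)  # hypothetical
metrics = runner.eval()  # hypothetical method that iterates self.validate_loader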
Example #5
def main():
    import sys
    import pathlib

    __dir__ = pathlib.Path(os.path.abspath(__file__))
    sys.path.append(str(__dir__))
    sys.path.append(str(__dir__.parent.parent))

    from models import build_model, build_loss
    from data_loader import get_dataloader
    from utils import Trainer
    from utils import get_post_processing
    from utils import get_metric

    config = anyconfig.load(open('config.yaml', 'rb'))
    train_loader = get_dataloader(config['dataset']['train'])
    validate_loader = get_dataloader(config['dataset']['validate'])
    criterion = build_loss(config['loss']).cuda()
    model = build_model(config['arch'])
    post_p = get_post_processing(config['post_processing'])
    metric = get_metric(config['metric'])

    trainer = Trainer(config=config,
                      model=model,
                      criterion=criterion,
                      train_loader=train_loader,
                      post_process=post_p,
                      metric_cls=metric,
                      validate_loader=validate_loader)
    trainer.train()
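Side note: anyconfig.load also accepts a plain path and infers the parser from the file extension, which avoids the explicit open(...) call:

config = anyconfig.load('config.yaml')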
Example #6
def main(args):
    cfg = Config.fromfile(args.config)
    for d in [cfg, cfg.data.test]:
        d.update(dict(report_speed=args.report_speed))
    print(json.dumps(cfg._cfg_dict, indent=4))
    sys.stdout.flush()

    device = paddle.get_device()
    paddle.set_device(device)

    # model
    model = build_model(cfg.model)

    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.checkpoint))
            sys.stdout.flush()

            checkpoint = paddle.load(args.checkpoint)
            model.set_state_dict(checkpoint)
        else:
            raise FileNotFoundError(
                "No checkpoint found at '{}'".format(args.checkpoint))

    # fuse conv and bn
    model = fuse_module(model)

    # test
    predict(args.input, model, cfg, args.output)
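fuse_module is not shown in this snippet; for reference, a minimal PyTorch-style sketch of the standard conv+BN folding it presumably performs (the Paddle version would be analogous):

import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Fold BN into the conv: y = gamma * (Wx + b - mean) / std + beta
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      groups=conv.groups, bias=True)
    std = torch.sqrt(bn.running_var + bn.eps)
    fused.weight.data = conv.weight.data * (bn.weight / std).reshape(-1, 1, 1, 1)
    bias = conv.bias.data if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data = bn.weight * (bias - bn.running_mean) / std + bn.bias
    return fused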
Example #7
def evaluate(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    preprocessor = torch.load(args.preprocessor_file)
    loader_config = dict(
        preprocessor=preprocessor,
        batch_size=args.batch_size,
        device=device,
    )
    eval_dataloader = create_dataloader(args.eval_file, **loader_config, shuffle=False)

    checkpoint = torch.load(args.checkpoint_file)
    model = build_model(
        word_vocab_size=len(preprocessor.vocabs["word"]),
        pretrained_word_vocab_size=len(preprocessor.vocabs["pretrained_word"]),
        postag_vocab_size=len(preprocessor.vocabs["postag"]),
        n_deprels=len(preprocessor.vocabs["deprel"]),
    )
    model.load_state_dict(checkpoint["model"])
    model.to(device)

    trainer = create_trainer(model)
    trainer.add_callback(utils.training.PrintCallback(printer=logger.info))
    deprel_map = {v: k for k, v in preprocessor.vocabs["deprel"].mapping.items()}
    trainer.add_callback(EvaluateCallback(args.eval_file, deprel_map, args.verbose), priority=0)
    with logging_redirect_tqdm(loggers=[logger]):
        trainer.evaluate(eval_dataloader)
Example #8
def main():
    args = get_arguments()

    print('Loading data...\n    DATABASE: {}'.format(args.database))
    X, Y, category_names = load_data(args.database)
    X_train, X_test, Y_train, Y_test = train_test_split(X,
                                                        Y,
                                                        test_size=0.2,
                                                        random_state=1)
    print("Will train with {} training data and validate with "
          "{} validation data".format(X_train.shape[0], X_test.shape[0]))

    print('Building model...')
    model = build_model()

    print('Training model...')
    start = time()
    model.fit(X_train, Y_train)
    print('Done training in {:.3f}s'.format(time() - start))

    # use the best estimator only
    print("Best parameters..")
    print(model.best_params_)
    model = model.best_estimator_

    print('Evaluating model...')
    evaluate_model(model, X_test, Y_test, category_names)

    print('Saving model...\n    MODEL: {}'.format(args.model))
    save_model(model, args.model)

    print('Trained model saved!')
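Since the script accesses model.best_params_ and model.best_estimator_, build_model here presumably returns a GridSearchCV; a minimal sketch of what it might contain (the pipeline steps and parameter grid are assumptions):

from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.model_selection import GridSearchCV
from sklearn.multioutput import MultiOutputClassifier
from sklearn.pipeline import Pipeline

def build_model():
    # text features followed by a multi-label classifier
    pipeline = Pipeline([
        ('vect', CountVectorizer()),
        ('tfidf', TfidfTransformer()),
        ('clf', MultiOutputClassifier(RandomForestClassifier())),
    ])
    param_grid = {'clf__estimator__n_estimators': [50, 100]}
    return GridSearchCV(pipeline, param_grid, cv=3)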
Example #9
def main(use_cuda, json_path, pretrained_model, args):
    data_set = HitachiDataset(json_path)
    test_loader = DataLoader(data_set, 1, shuffle=False)

    model = build_model(args.model_name, args)

    # load weights from the given checkpoint
    checkpoint = torch.load(pretrained_model)
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)
    model.cuda()

    model = nn.DataParallel(model, device_ids=[int(e) for e in args.gpu_ids]) if args.use_cuda else model
    model.eval()
    with torch.no_grad():
        for k, (val_img, img_path) in enumerate(tqdm(test_loader)):
            # if k > 0: break
            if use_cuda:
                val_img = val_img.cuda()

            pred_depth = model(val_img)

            pred_depth = pred_depth.cpu().numpy().squeeze()
            pred_depth_scale = (pred_depth * 1000).astype(np.uint16)  # PNG expects an integer dtype

            # DataLoader collates the path strings into a list, so unwrap the single element
            save_name = 'Depth_' + img_path[0].split('/')[-1].split('_')[-1].replace('jpg', 'png')
            cv2.imwrite(os.path.join('result', save_name), pred_depth_scale, [cv2.IMWRITE_PNG_COMPRESSION, 0])
Example #10
def test_model():
    model_ = models.build_model()
    datas, label = data.read_test_data()
    label = label.detach().numpy()
    y = model_(datas).detach().numpy()
    # print(y.shape, label.shape)
    print(tool.score(label, y))
Example #11
def build():
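    # Cypher: delete every node together with all of its relationships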
    
    db.graph.run("MATCH (n) DETACH DELETE n")
    
    m = build_model()

    return jsonify({'status': m})
Example #12
def run_train(args):

    device = torch.device(args.device)
    # build student
    student = build_model(args.student, args.num_classes, args.pretrained)
    student = student.to(device)
    # build teachers
    teachers = build_teachers(args, device)
    # build checkpointer, optimizer, scheduler, logger
    optimizer = build_optimizer(args, student)
    scheduler = build_lr_scheduler(args, optimizer)
    checkpointer = Checkpointer(student, optimizer, scheduler, args.experiment, args.checkpoint_period)
    logger = Logger(os.path.join(args.experiment, 'tf_log'))

    # objective function to train student
    loss_fn = loss_fn_kd

    # data_load
    train_loader = CIFAR10_loader(args, is_train=True)
    test_loader = CIFAR10_loader(args, is_train=False)

    acc1, m_acc1 = inference(student, test_loader, logger, device, 0, args)
    checkpointer.best_acc = acc1
    for epoch in tqdm(range(0, args.max_epoch)):
        do_train(student, teachers, loss_fn, train_loader, optimizer, checkpointer, device, logger, epoch)
        acc1, m_acc1 = inference(student, test_loader, logger, device, epoch+1, args)
        if acc1 > checkpointer.best_acc:
            checkpointer.save("model_best")
            checkpointer.best_acc = acc1
        scheduler.step()
    
    checkpointer.save("model_last")
    def __init__(self, lr, l2, model_name, histories=None, img_shape=(384, 384, 1), step=0, use_val=True, small_dataset=False):
        self.model, self.branch_model, self.head_model = build_model(lr, l2)
        # avoid sharing one mutable default list across instances
        self.histories = histories if histories is not None else []
        self.step = step
        self.img_shape = img_shape
        self.img_gen = ImageGenerator()
        self.best_map5 = 0
        self.model_name = model_name
        # Make callback list
        self.callback_list = self.make_callback_list()
        # Load training examples
        if small_dataset:
            self.train = load_pickle_file(train_examples_small_file)
            print('SMALL DATASET')
        else:
            self.train = load_pickle_file(train_examples_file)
        if small_dataset:
            validation_data = load_pickle_file(validation_examples_small_file)
        else:
            validation_data = load_pickle_file(validation_examples_file)
        if use_val:
            self.validation = ValData(validation_data, self.img_gen.read_for_testing, batch_size=16)
        else:
            self.validation = None
        # Map each whale id to its training examples
        self.w2ts = self.make_w2ts()
Example #14
def test_train():
    model_ = models.build_model()
    # device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = DEVICE
    dataloader = data.read_data(mean=False,
                                in_range=False,
                                val=False,
                                start_random=True)
    train_loader = data.read_data(mean=False,
                                  in_range=False,
                                  dataset="SODA",
                                  val=False)
    datas, label = data.read_test_data(mean=False, in_range=False)
    datas = datas.to(device)
    label = label.detach().numpy()
    model_.to(device)
    model_.eval()
    y = model_(datas).cpu().detach().numpy()
    print(tool.score(label, y))
    model_.train()
    train.train(model_, dataloader, train_loader)
    model_.eval()
    y = model_(datas).cpu().detach().numpy()
    print(tool.score(label, y))
    print(np.abs(y - label))
Example #15
def evaluate(config):
    """Orchestrates the evaluation process.

    This method is responsible for executing all the steps required to
    evaluate a trained model, which includes:
    - Preparing the evaluation dataset
    - Building the model
    - Scoring the generated captions with the COCO evaluation toolkit

    Arguments:
        config (util.Config): Values for various configuration options
    """

    eval_dataset, vocabulary, coco_eval = prepare_eval_data(config)
    model = build_model(config, vocabulary)
    # start = time.time()
    # results = eval(model, eval_dataset, vocabulary, config)
    # logging.info('Total caption generation time: %d seconds', time.time() - start)

    # Evaluate these captions
    start = time.time()
    coco_eval_result = coco_eval.loadRes(config.eval_result_file)
    scorer = COCOEvalCap(coco_eval, coco_eval_result)
    scorer.evaluate()
    logging.info("Evaluation complete.")
    logging.info('Total evaluation time: %d seconds', time.time() - start)
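After scorer.evaluate() the per-metric scores are exposed on the scorer object; assuming the standard pycocoevalcap API, they could be logged like this:

for metric, score in scorer.eval.items():
    logging.info('%s: %.3f', metric, score)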
Example #16
    def __init__(self, model_path, gpu_id=0):
        from models import build_model
        from data_loader import get_dataloader
        from post_processing import get_post_processing
        from utils import get_metric
        self.gpu_id = gpu_id
        if self.gpu_id is not None and isinstance(self.gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % self.gpu_id)
            torch.backends.cudnn.benchmark = True
        else:
            self.device = torch.device("cpu")
        # print(self.gpu_id)
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        config = checkpoint['config']
        config['arch']['backbone']['pretrained'] = False
        config['dataset']['train']['dataset']['args']['data_path'][0] = '/home/share/gaoluoluo/dbnet/datasets/train_zhen.txt'
        config['dataset']['validate']['dataset']['args']['data_path'][0] = '/home/share/gaoluoluo/dbnet/datasets/test_zhen.txt'

        print("config:",config)
        self.validate_loader = get_dataloader(config['dataset']['validate'], config['distributed'])

        self.model = build_model(config['arch'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.to(self.device)

        self.post_process = get_post_processing(config['post_processing'])
        self.metric_cls = get_metric(config['metric'])
Example #17
def pth_params_2_ONNX():
    batch_size = 1
    model_config = {
        'backbone': {'type': 'resnet18', 'pretrained': True, "in_channels": 3},
        'neck': {'type': 'FPN', 'inner_channels': 256},  # segmentation neck: FPN or FPEM_FFM
        'head': {'type': 'DBHead', 'out_channels': 2, 'k': 50},
    }
    model = build_model('Model', **model_config).cuda()
    model_path = "/red_detection/DBNet/DBNet_fzh/phone_code_model/model_0.87_depoly.pth"
    model.load_state_dict(torch.load(model_path))
    model.eval()

    input_shape = (3, 736, 736)  # input size; change this to your own input shape (ResNet)
    example = torch.randn(batch_size, *input_shape, dtype=torch.float32)  # build a dummy input tensor
    example = example.cuda()
    export_onnx_file = "/red_detection/DBNet/DBNet_fzh/phone_code_model/model_0.87_depoly.onnx"  # target ONNX file name
    # torch.onnx.export(model, example, export_onnx_file, opset_version=11, input_names=["input"], output_names=['output'], verbose=True)
    # torch.onnx.export(model, example, export_onnx_file,
    #                   opset_version=10,
    #                   do_constant_folding=True,   # run constant-folding optimization
    #                   input_names=["input"],      # input names
    #                   output_names=["output"],    # output names
    #                   dynamic_axes={"input": {0: "batch_size"},    # variable batch dimension
    #                                 "output": {0: "batch_size"}})
    _ = torch.onnx.export(model,              # model being run
                          example,            # model input (or a tuple for multiple inputs)
                          export_onnx_file,
                          opset_version=10,
                          verbose=False,      # do not print the exported graph
                          training=False,
                          do_constant_folding=True,
                          input_names=['input'],
                          output_names=['output'])
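A quick sanity check of the exported file, assuming onnxruntime is installed; the names and shapes mirror the export above:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(export_onnx_file, providers=['CPUExecutionProvider'])
dummy = np.random.randn(1, 3, 736, 736).astype(np.float32)
outputs = sess.run(['output'], {'input': dummy})
print(outputs[0].shape)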
Example #18
def main(args):
    this_dir = osp.join(osp.dirname(__file__), '.')
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_loader = utl.build_data_loader(args, 'extract')

    model = build_model(args).to(device)
    model.load_state_dict(torch.load(args.checkpoint))
    model.train(False)

    with torch.set_grad_enabled(False):
        for batch_idx, (data, air_target, bed_target,
                        save_path) in enumerate(data_loader):
            print('{:3.3f}%'.format(100.0 * batch_idx / len(data_loader)))
            batch_size = data.shape[0]
            data = data.to(device)
            air_feature, bed_feature = model.features(data)
            air_feature = air_feature.to('cpu').numpy()
            bed_feature = bed_feature.to('cpu').numpy()
            for bs in range(batch_size):
                if not osp.isdir(osp.dirname(save_path[bs])):
                    os.makedirs(osp.dirname(save_path[bs]))
                np.save(
                    save_path[bs],
                    np.concatenate((air_feature[bs], bed_feature[bs]), axis=0))
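Minor note: the isdir check followed by makedirs can race if several processes extract features at once; os.makedirs with exist_ok=True (Python 3.2+) handles the existing-directory case without the race:

os.makedirs(osp.dirname(save_path[bs]), exist_ok=True)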
Example #19
def main(args):
    # reads the config file into a dict; a very handy helper
    cfg = Config.fromfile(args.config)
    for d in [cfg, cfg.data.test]:
        d.update(dict(report_speed=args.report_speed))

    # model
    model = build_model(cfg.model)
    model = model.cuda()

    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.checkpoint))
            sys.stdout.flush()

            checkpoint = torch.load(args.checkpoint)

            # strip the 'module.' prefix that DataParallel added when the
            # checkpoint was saved
            d = dict()
            for key, value in checkpoint['state_dict'].items():
                d[key[7:]] = value
            model.load_state_dict(d)
        else:
            raise FileNotFoundError(
                "No checkpoint found at '{}'".format(args.checkpoint))

    # fuse conv and bn
    model = fuse_module(model)

    print("Video Process Begin!")
    if args.online:
        test_online(model, cfg)
    else:
        test_offline(args.input, args.output, model, cfg)
Example #20
def train(add, num_testing, object_dim, job_name, **kwargs):

	coco_train = COCO('./annotations/instances_train2014.json')
	coco_test = COCO('./annotations/instances_val2014.json')

	training_generator = create_generator(coco = coco_train, mode = 'training', add = add, object_dim = object_dim, **kwargs)
	testing_generator = create_generator(coco = coco_test, mode = 'testing', add = add, object_dim = object_dim, **kwargs)

	model = build_model(add = add, object_dim = object_dim, **kwargs)

	model.compile(loss = 'categorical_crossentropy', optimizer = Adam(1e-6, beta_1=.9, beta_2=.99), metrics = ['accuracy'])

	# callbacks_list = [ModelCheckpoint('./'+job_name+'.h5', monitor='val_loss', verbose=1, save_best_only=True, mode='min')]

	callbacks_list = [ModelCheckpoint('./'+job_name+'.h5', monitor='val_loss', verbose=1, mode='min')]

	print(model.summary())

	history = model.fit_generator(
		training_generator,
		validation_steps = num_testing,
		validation_data = testing_generator,  # was training_generator, which validated on the training data
		steps_per_epoch = 100,  # 5000
		epochs = 5,  # 500
		# callbacks = callbacks_list,
		verbose = 1,
		max_queue_size = 2,  # 10
		workers = 1,
		)
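Note: Model.fit_generator is deprecated in TensorFlow 2.x; model.fit accepts generators directly with the same keyword arguments:

history = model.fit(
    training_generator,
    validation_data=testing_generator,
    validation_steps=num_testing,
    steps_per_epoch=100,
    epochs=5,
    verbose=1,
    max_queue_size=2,
    workers=1,
)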
Example #21
def load_model_state(filename, device, data_parallel=False):
    if not os.path.exists(filename):
        print("Starting training from scratch.")
        return 0

    def dict_to_sns(d):
        return SimpleNamespace(**d)

    basedir = os.path.dirname(filename)
    with open(os.path.join(basedir, 'config.json')) as f:
        args_dict = json.load(f, object_hook=dict_to_sns)

    model = build_model(args_dict, device)

    print("Loading model from checkpoints", filename)
    state = torch.load(
        filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))

    from collections import OrderedDict
    new_state_dict = OrderedDict()
    # create new OrderedDict that does not contain `module.`
    if data_parallel:
        for k, v in state['model'].items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
    else:
        new_state_dict = state['model']
    # load model parameters
    try:
        model.load_state_dict(new_state_dict)
    except Exception as e:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match') from e
    return model, args_dict
Example #22
def show_model_summary(X, y, config):

    n_feature_sets = len(X)
    model, loss = models.build_model(config)
    model.compile(loss=loss, optimizer='adam')
    model.build((None, np.concatenate(X, axis=-1).shape[1]))
    print(model.summary())
Example #23
def test(cfg):
    """
    Perform multi-view testing on the trained video model.
    Args:
        cfg (CfgNode): configs. Details can be found in
        config.py
    """
    # Set random seed from configs.
    if cfg.RNG_SEED != -1:
        random.seed(cfg.RNG_SEED)
        np.random.seed(cfg.RNG_SEED)
        torch.manual_seed(cfg.RNG_SEED)
        torch.cuda.manual_seed_all(cfg.RNG_SEED)

    # Setup logging format.
    logging.setup_logging(cfg.NUM_GPUS)

    # Print config.
    logger.info("Test with config:")
    logger.info(pprint.pformat(cfg))

    # Model for testing
    model = build_model(cfg)
    # Print model statistics.
    if du.is_master_proc(cfg.NUM_GPUS):
        misc.log_model_info(model, cfg, use_train_input=False)

    if cfg.TEST.CHECKPOINT_FILE_PATH:
        if os.path.isfile(cfg.TEST.CHECKPOINT_FILE_PATH):
            logger.info("=> loading checkpoint '{}'".format(
                cfg.TEST.CHECKPOINT_FILE_PATH))
            ms = model.module if cfg.NUM_GPUS > 1 else model
            # Load the checkpoint on CPU to avoid GPU mem spike.
            checkpoint = torch.load(cfg.TEST.CHECKPOINT_FILE_PATH,
                                    map_location='cpu')
            ms.load_state_dict(checkpoint['state_dict'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                cfg.TEST.CHECKPOINT_FILE_PATH, checkpoint['epoch']))
    else:
        logger.info("Test with random initialization for debugging")

    # Create video testing loaders
    test_loader = loader.construct_loader(cfg, "test")
    logger.info("Testing model for {} iterations".format(len(test_loader)))

    # Create meters for multi-view testing.
    test_meter = TestMeter(
        cfg.TEST.DATASET_SIZE,
        cfg.TEST.NUM_ENSEMBLE_VIEWS * cfg.TEST.NUM_SPATIAL_CROPS,
        cfg.MODEL.NUM_CLASSES,
        len(test_loader),
        cfg.DATA.MULTI_LABEL,
        cfg.DATA.ENSEMBLE_METHOD,
        cfg.LOG_PERIOD,
    )

    cudnn.benchmark = True

    # Perform multi-view test on the entire dataset.
    perform_test(test_loader, model, test_meter, cfg)
Example #24
    def __init__(self, model_path, post_p_thre=0.7, gpu_id=None):
        '''
        Initialize the PyTorch model.
        :param model_path: path to the model file (either the parameters alone
            or the parameters saved together with the computation graph)
        :param gpu_id: id of the GPU to run on
        '''
        self.gpu_id = gpu_id

        if self.gpu_id is not None and isinstance(
                self.gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % self.gpu_id)
        else:
            self.device = torch.device("cpu")
        print('device:', self.device)
        checkpoint = torch.load(model_path, map_location=self.device)

        config = checkpoint['config']
        config['arch']['backbone']['pretrained'] = False
        self.model = build_model(config['arch'].pop('type'), **config['arch'])
        self.post_process = get_post_processing(config['post_processing'])
        self.post_process.box_thresh = post_p_thre
        self.img_mode = config['dataset']['train']['dataset']['args'][
            'img_mode']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.to(self.device)
        self.model.eval()

        self.transform = []
        for t in config['dataset']['train']['dataset']['args']['transforms']:
            if t['type'] in ['ToTensor', 'Normalize']:
                self.transform.append(t)
        self.transform = get_transforms(self.transform)
Example #25
def run_train(args):

    device = torch.device(args.device)

    model = build_model(args.model_name, args.num_classes, args.pretrained)
    model = model.to(device)
    # build checkpointer, optimizer, scheduler, logger
    optimizer = build_optimizer(args, model)
    scheduler = build_lr_scheduler(args, optimizer)
    checkpointer = Checkpointer(model, optimizer, scheduler, args.experiment,
                                args.checkpoint_period)
    logger = Logger(os.path.join(args.experiment, 'tf_log'))

    # data_load
    train_loader = CIFAR10_loader(args, is_train=True)
    test_loader = CIFAR10_loader(args, is_train=False)

    acc1, _ = inference(model, test_loader, logger, device, 0, args)
    checkpointer.best_acc = acc1
    for epoch in tqdm(range(0, args.max_epoch)):
        train_epoch(model, train_loader, optimizer,
                    len(train_loader) * epoch, checkpointer, device, logger)
        acc1, m_acc1 = inference(model, test_loader, logger, device, epoch + 1,
                                 args)
        if acc1 > checkpointer.best_acc:
            checkpointer.save("model_best")
            checkpointer.best_acc = acc1
        scheduler.step()

    checkpointer.save("model_last")
Example #26
def test(config):
    if not config.config_path or not config.restore_from:
        raise AttributeError('You need to specify config_path and restore_from')
    else:
        config = load_config(config, config.config_path)

    set_logger(config)

    char_vocab = Vocab()
    char_vocab.load_from(os.path.join(config.vocab_dir, 'char_vocab.data'))
    label_vocab = Vocab(use_special_token=False)
    label_vocab.load_from(os.path.join(config.vocab_dir, 'label_vocab.data'))

    test_set = build_dataset(config, 'test', char_vocab, label_vocab)
    inputs = build_inputs(test_set.output_types, test_set.output_shapes)

    model = build_model(config, inputs)
    eval_metrics, results = model.evaluate(test_set)

    print('Eval metrics: {}'.format(eval_metrics))
    if config.result_name:
        with open(os.path.join(config.result_dir, config.result_name + '.json'), 'w') as f:
            json.dump(eval_metrics, f, indent=4)

        with open(os.path.join(config.result_dir, config.result_name + '.txt'), 'w') as f:
            for result in results:
                f.write(label_vocab.id2token[result] + '\n')
Example #27
def main():
    parser = argparse.ArgumentParser(description='Evaluate model')
    parser.add_argument('--model_config',
                        default='config.json',
                        help='train config for model_weights')
    parser.add_argument('--model_weights',
                        default='./checkpoints/en_de_final.pt',
                        help='path for weights of the model')

    args = parser.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    with open(os.path.join('checkpoints', args.model_config), 'rt') as f:
        model_args = argparse.Namespace()
        model_args.__dict__.update(json.load(f))
        model_args = parser.parse_args(namespace=model_args)

    train_data, valid_data, test_data, src_lang, trg_lang = prepare_data()
    model = build_model(model_args, src_lang, trg_lang, len(src_lang.vocab),
                        len(trg_lang.vocab), device)
    model.load_state_dict(torch.load(args.model_weights, map_location='cpu'))
    model.eval()

    log.info(
        'Bleu score: %s',
        calculate_bleu(test_data,
                       src_lang,
                       trg_lang,
                       model,
                       device,
                       max_len=100))
Example #28
def main():
    args = arg_parser.parse_args()
    config_name = args.config.split("/")[-1].split(".")[0]

    with open(args.log_config) as log_config_f:
        log_filename = "logs/%s.log" % config_name
        log_config = yaml.safe_load(log_config_f)
        log_config["handlers"]["fileHandler"]["filename"] = log_filename
        logging.config.dictConfig(log_config)

    with open(args.config) as config_f:
        config = util.Struct(**yaml.safe_load(config_f))

    task = tasks.load_task(config.task)
    model = models.build_model(config.model, config.opt)

    for i_iter in range(config.opt.iters):
        do_eval = i_iter % 5 == 0
        train_loss, train_acc = batched_iter(
                task.train, model, config, train=True, compute_eval=do_eval)
        if do_eval:
            val_loss, val_acc = batched_iter(
                    task.val, model, config, compute_eval=True)
            visualizer.begin(config_name, 100)
            test_loss, test_acc = batched_iter(
                    task.test, model, config, compute_eval=True)
            visualizer.end()
            logging.info("%5d  :  %2.4f  %2.4f  %2.4f  :  %2.4f  %2.4f  %2.4f",
                    i_iter, train_loss, val_loss, test_loss, train_acc, val_acc, 
                    test_acc)
            model.save("saves/%s_%d.caffemodel" % (config_name, i_iter))
        else:
            logging.info("%5d  :  %2.4f", i_iter, train_loss)
Example #29
    def init_model(self):
        args = self.args

        pretrained = not args.resume

        model = build_model(args, pretrained=pretrained)

        return model
Example #30
def main(config_path: str = 'config.yaml'):

    typer.secho("--------------- start training ---------------",
                fg=typer.colors.GREEN)

    # load config
    config = load_yaml(config_path)

    typer.secho(f"loaded config : {config_path}", fg=typer.colors.GREEN)

    # load dataset
    typer.secho(f"loaded train dataset : {config['train_data']}",
                fg=typer.colors.GREEN)
    typer.secho(f"loaded test dataset : {config['test_data']}",
                fg=typer.colors.GREEN)

    train_ds = load_tfrecords(config['train_data'], config)
    test_ds = load_tfrecords(config['test_data'], config)

    train_ds = (train_ds.shuffle(
        buffer_size=config['train_image_count'] // 10).repeat().batch(
            config['batch_size']).prefetch(buffer_size=AUTOTUNE))

    test_ds = (test_ds.repeat().batch(
        config['batch_size']).prefetch(buffer_size=AUTOTUNE))

    typer.secho("--------------- build model ---------------",
                fg=typer.colors.GREEN)

    input_shape = (config['input_width'], config['input_height'],
                   config['input_chanel'])

    typer.secho(f"base model : {config['base_model']}", fg=typer.colors.GREEN)

    inputs = tf.keras.Input(shape=input_shape)

    model = build_model(config, inputs)

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.summary()
    log_dir = f'logs/{config["base_model"]}'
    os.makedirs(log_dir, exist_ok=True)

    tfboard_callbacks = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

    model.fit(
        train_ds,
        steps_per_epoch=config['train_image_count'] // config['batch_size'],
        validation_data=test_ds,
        validation_steps=config['test_image_count'] // config['batch_size'],
        epochs=config['epochs'],
        callbacks=[tfboard_callbacks])

    typer.secho("--------------- end training ---------------",
                fg=typer.colors.GREEN)
Example #31
def prepare_model(args):
    model = build_model(args, pretrained=False)

    load_model(model, args.weight)

    if torch.cuda.is_available():
        model = model.cuda()

    return model
Example #32
    def __init__(self, config, opt_config):
        self.config = config
        self.opt_config = opt_config

        self.models = []
        for cmodel in config.models:
            with open(cmodel.config) as config_f:
                mconfig = Struct(**yaml.safe_load(config_f))
                model = models.build_model(mconfig.model, mconfig.opt)
            model.load(cmodel.weights)
            self.models.append(model)

        self.n_models = len(self.models)

        self.apollo_net = ApolloNet()
Example #33
def main():
    config = configure()
    task = tasks.load_task(config)
    model = models.build_model(config.model, config.opt)

    for i_epoch in range(config.opt.iters):

        train_loss, train_acc, _ = do_iter(task.train, model, config, train=True)
        val_loss, val_acc, val_predictions = do_iter(task.val, model, config, vis=True)
        test_loss, test_acc, test_predictions = do_iter(task.test, model, config)

        logging.info(
            "%5d  |  %8.3f  %8.3f  %8.3f  |  %8.3f  %8.3f  %8.3f",
            i_epoch,
            train_loss,
            val_loss,
            test_loss,
            train_acc,
            val_acc,
            test_acc,
        )

        with open("logs/val_predictions_%d.json" % i_epoch, "w") as pred_f:
            print >> pred_f, json.dumps(val_predictions)
Example #34
def main(arguments):
    ''' Train or load a model. Evaluate on some tasks. '''
    parser = argparse.ArgumentParser(description='')

    # Logistics
    parser.add_argument('--cuda', help='-1 if no CUDA, else gpu id', type=int, default=0)
    parser.add_argument('--random_seed', help='random seed to use', type=int, default=19)

    # Paths and logging
    parser.add_argument('--log_file', help='file to log to', type=str, default='log.log')
    parser.add_argument('--exp_dir', help='directory containing shared preprocessing', type=str)
    parser.add_argument('--run_dir', help='directory for saving results, models, etc.', type=str)
    parser.add_argument('--word_embs_file', help='file containing word embs', type=str, default='')
    parser.add_argument('--preproc_file', help='file containing saved preprocessing stuff',
                        type=str, default='preproc.pkl')

    # Time saving flags
    parser.add_argument('--should_train', help='1 if should train model', type=int, default=1)
    parser.add_argument('--load_model', help='1 if load from checkpoint', type=int, default=1)
    parser.add_argument('--load_epoch', help='Force loading from a certain epoch', type=int,
                        default=-1)
    parser.add_argument('--load_tasks', help='1 if load tasks', type=int, default=1)
    parser.add_argument('--load_preproc', help='1 if load vocabulary', type=int, default=1)

    # Tasks and task-specific classifiers
    parser.add_argument('--train_tasks', help='comma separated list of tasks, or "all" or "none"',
                        type=str)
    parser.add_argument('--eval_tasks', help='list of additional tasks to train a classifier, '
                        'then evaluate on', type=str, default='')
    parser.add_argument('--classifier', help='type of classifier to use', type=str,
                        default='log_reg', choices=['log_reg', 'mlp', 'fancy_mlp'])
    parser.add_argument('--classifier_hid_dim', help='hid dim of classifier', type=int, default=512)
    parser.add_argument('--classifier_dropout', help='classifier dropout', type=float, default=0.0)

    # Preprocessing options
    parser.add_argument('--max_seq_len', help='max sequence length', type=int, default=40)
    parser.add_argument('--max_word_v_size', help='max word vocab size', type=int, default=30000)

    # Embedding options
    parser.add_argument('--dropout_embs', help='dropout rate for embeddings', type=float, default=.2)
    parser.add_argument('--d_word', help='dimension of word embeddings', type=int, default=300)
    parser.add_argument('--glove', help='1 if use glove, else from scratch', type=int, default=1)
    parser.add_argument('--train_words', help='1 if make word embs trainable', type=int, default=0)
    parser.add_argument('--elmo', help='1 if use elmo', type=int, default=0)
    parser.add_argument('--deep_elmo', help='1 if use elmo post LSTM', type=int, default=0)
    parser.add_argument('--elmo_no_glove', help='1 if no glove, assuming elmo', type=int, default=0)
    parser.add_argument('--cove', help='1 if use cove', type=int, default=0)

    # Model options
    parser.add_argument('--pair_enc', help='type of pair encoder to use', type=str, default='simple',
                        choices=['simple', 'attn'])
    parser.add_argument('--d_hid', help='hidden dimension size', type=int, default=4096)
    parser.add_argument('--n_layers_enc', help='number of RNN layers', type=int, default=1)
    parser.add_argument('--n_layers_highway', help='num of highway layers', type=int, default=1)
    parser.add_argument('--dropout', help='dropout rate to use in training', type=float, default=.2)

    # Training options
    parser.add_argument('--no_tqdm', help='1 to turn off tqdm', type=int, default=0)
    parser.add_argument('--trainer_type', help='type of trainer', type=str,
                        choices=['sampling', 'mtl'], default='sampling')
    parser.add_argument('--shared_optimizer', help='1 to use same optimizer for all tasks',
                        type=int, default=1)
    parser.add_argument('--batch_size', help='batch size', type=int, default=64)
    parser.add_argument('--optimizer', help='optimizer to use', type=str, default='sgd')
    parser.add_argument('--n_epochs', help='n epochs to train for', type=int, default=10)
    parser.add_argument('--lr', help='starting learning rate', type=float, default=1.0)
    parser.add_argument('--min_lr', help='minimum learning rate', type=float, default=1e-5)
    parser.add_argument('--max_grad_norm', help='max grad norm', type=float, default=5.)
    parser.add_argument('--weight_decay', help='weight decay value', type=float, default=0.0)
    parser.add_argument('--task_patience', help='patience in decaying per task lr',
                        type=int, default=0)
    parser.add_argument('--scheduler_threshold', help='scheduler threshold',
                        type=float, default=0.0)
    parser.add_argument('--lr_decay_factor', help='lr decay factor when val score doesn\'t improve',
                        type=float, default=.5)

    # Multi-task training options
    parser.add_argument('--val_interval', help='Number of passes between validation checks',
                        type=int, default=10)
    parser.add_argument('--max_vals', help='Maximum number of validation checks', type=int,
                        default=100)
    parser.add_argument('--bpp_method', help='if using nonsampling trainer, ' +
                        'method for calculating number of batches per pass', type=str,
                        choices=['fixed', 'percent_tr', 'proportional_rank'], default='fixed')
    parser.add_argument('--bpp_base', help='If sampling or fixed bpp '
                        'per pass, this is the bpp. If proportional, this '
                        'is the smallest number', type=int, default=10)
    parser.add_argument('--weighting_method', help='Weighting method for sampling', type=str,
                        choices=['uniform', 'proportional'], default='uniform')
    parser.add_argument('--scaling_method', help='method for scaling loss', type=str,
                        choices=['min', 'max', 'unit', 'none'], default='none')
    parser.add_argument('--patience', help='patience in early stopping', type=int, default=5)
    parser.add_argument('--task_ordering', help='Method for ordering tasks', type=str, default='given',
                        choices=['given', 'random', 'random_per_pass', 'small_to_large', 'large_to_small'])

    args = parser.parse_args(arguments)

    # Logistics #
    log.basicConfig(format='%(asctime)s: %(message)s', level=log.INFO, datefmt='%m/%d %I:%M:%S %p')
    log_file = os.path.join(args.run_dir, args.log_file)
    file_handler = log.FileHandler(log_file)
    log.getLogger().addHandler(file_handler)
    log.info(args)
    seed = random.randint(1, 10000) if args.random_seed < 0 else args.random_seed
    random.seed(seed)
    torch.manual_seed(seed)
    if args.cuda >= 0:
        log.info("Using GPU %d", args.cuda)
        torch.cuda.set_device(args.cuda)
        torch.cuda.manual_seed_all(seed)
    log.info("Using random seed %d", seed)

    # Load tasks #
    log.info("Loading tasks...")
    start_time = time.time()
    train_tasks, eval_tasks, vocab, word_embs = build_tasks(args)
    tasks = train_tasks + eval_tasks
    log.info('\tFinished loading tasks in %.3fs', time.time() - start_time)

    # Build model #
    log.info('Building model...')
    start_time = time.time()
    model = build_model(args, vocab, word_embs, tasks)
    log.info('\tFinished building model in %.3fs', time.time() - start_time)

    # Set up trainer #
    # TODO(Alex): move iterator creation
    iterator = BasicIterator(args.batch_size)
    #iterator = BucketIterator(sorting_keys=[("sentence1", "num_tokens")], batch_size=args.batch_size)
    trainer, train_params, opt_params, schd_params = build_trainer(args, args.trainer_type, model, iterator)

    # Train #
    if train_tasks and args.should_train:
        #to_train = [p for p in model.parameters() if p.requires_grad]
        to_train = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        if args.trainer_type == 'mtl':
            best_epochs = trainer.train(train_tasks, args.task_ordering, args.val_interval,
                                        args.max_vals, args.bpp_method, args.bpp_base, to_train,
                                        opt_params, schd_params, args.load_model)
        elif args.trainer_type == 'sampling':
            if args.weighting_method == 'uniform':
                log.info("Sampling tasks uniformly")
            elif args.weighting_method == 'proportional':
                log.info("Sampling tasks proportional to number of training batches")

            if args.scaling_method == 'max':
                # divide by # batches, multiply by max # batches
                log.info("Scaling losses to largest task")
            elif args.scaling_method == 'min':
                # divide by # batches, multiply by fewest # batches
                log.info("Scaling losses to the smallest task")
            elif args.scaling_method == 'unit':
                log.info("Dividing losses by number of training batches")
            best_epochs = trainer.train(train_tasks, args.val_interval, args.bpp_base,
                                        args.weighting_method, args.scaling_method, to_train,
                                        opt_params, schd_params, args.shared_optimizer,
                                        args.load_model)
    else:
        log.info("Skipping training.")
        best_epochs = {}

    # train just the classifiers for eval tasks
    for task in eval_tasks:
        pred_layer = getattr(model, "%s_pred_layer" % task.name)
        to_train = pred_layer.parameters()
        trainer = MultiTaskTrainer.from_params(model, args.run_dir + '/%s/' % task.name,
                                               iterator, copy.deepcopy(train_params))
        trainer.train([task], args.task_ordering, 1, args.max_vals, 'percent_tr', 1, to_train,
                      opt_params, schd_params, 1)
        layer_path = os.path.join(args.run_dir, task.name, "%s_best.th" % task.name)
        layer_state = torch.load(layer_path, map_location=device_mapping(args.cuda))
        model.load_state_dict(layer_state)

    # Evaluate: load the different task best models and evaluate them
    # TODO(Alex): put this in evaluate file
    all_results = {}

    if not best_epochs and args.load_epoch >= 0:
        epoch_to_load = args.load_epoch
    elif not best_epochs:  # no forced epoch; fall back to the newest checkpoint on disk
        serialization_files = os.listdir(args.run_dir)
        model_checkpoints = [x for x in serialization_files if "model_state_epoch" in x]
        epoch_to_load = max([int(x.split("model_state_epoch_")[-1].strip(".th")) \
                             for x in model_checkpoints])
    else:
        epoch_to_load = -1

    #for task in [task.name for task in train_tasks] + ['micro', 'macro']:
    for task in ['macro']:
        log.info("Testing on %s..." % task)

        # Load best model
        load_idx = best_epochs[task] if best_epochs else epoch_to_load
        model_path = os.path.join(args.run_dir, "model_state_epoch_{}.th".format(load_idx))
        model_state = torch.load(model_path, map_location=device_mapping(args.cuda))
        model.load_state_dict(model_state)

        # Test evaluation and prediction
        # could just filter out tasks to get what i want...
        #tasks = [task for task in tasks if 'mnli' in task.name]
        te_results, te_preds = evaluate(model, tasks, iterator, cuda_device=args.cuda, split="test")
        val_results, _ = evaluate(model, tasks, iterator, cuda_device=args.cuda, split="val")

        if task == 'macro':
            all_results[task] = (val_results, te_results, model_path)
            for eval_task, task_preds in te_preds.items(): # write predictions for each task
                #if 'mnli' not in eval_task:
                #    continue
                idxs_and_preds = [(idx, pred) for pred, idx in zip(task_preds[0], task_preds[1])]
                idxs_and_preds.sort(key=lambda x: x[0])
                if 'mnli' in eval_task:
                    pred_map = {0: 'neutral', 1: 'entailment', 2: 'contradiction'}
                    with open(os.path.join(args.run_dir, "%s-m.tsv" % (eval_task)), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        split_idx = 0
                        for idx, pred in idxs_and_preds[:9796]:
                            pred = pred_map[pred]
                            pred_fh.write("%d\t%s\n" % (split_idx, pred))
                            split_idx += 1
                    with open(os.path.join(args.run_dir, "%s-mm.tsv" % (eval_task)), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        split_idx = 0
                        for idx, pred in idxs_and_preds[9796:9796+9847]:
                            pred = pred_map[pred]
                            pred_fh.write("%d\t%s\n" % (split_idx, pred))
                            split_idx += 1
                    with open(os.path.join(args.run_dir, "diagnostic.tsv"), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        split_idx = 0
                        for idx, pred in idxs_and_preds[9796+9847:]:
                            pred = pred_map[pred]
                            pred_fh.write("%d\t%s\n" % (split_idx, pred))
                            split_idx += 1
                else:
                    with open(os.path.join(args.run_dir, "%s.tsv" % (eval_task)), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        for idx, pred in idxs_and_preds:
                            if 'sts-b' in eval_task:
                                pred_fh.write("%d\t%.3f\n" % (idx, pred))
                            elif 'rte' in eval_task:
                                pred = 'entailment' if pred else 'not_entailment'
                                pred_fh.write('%d\t%s\n' % (idx, pred))
                            elif 'squad' in eval_task:
                                pred = 'entailment' if pred else 'not_entailment'
                                pred_fh.write('%d\t%s\n' % (idx, pred))
                            else:
                                pred_fh.write("%d\t%d\n" % (idx, pred))

            with open(os.path.join(args.exp_dir, "results.tsv"), 'a') as results_fh: # aggregate results easily
                run_name = args.run_dir.split('/')[-1]
                all_metrics_str = ', '.join(['%s: %.3f' % (metric, score) for \
                                            metric, score in val_results.items()])
                results_fh.write("%s\t%s\n" % (run_name, all_metrics_str))
    log.info("Done testing")

    # Dump everything to a pickle for posterity
    pkl.dump(all_results, open(os.path.join(args.run_dir, "results.pkl"), 'wb'))