Example #1
def main(opt, device_id):
  opt = training_opt_postprocessing(opt, device_id)
  init_logger(opt.log_file)
  # Load checkpoint if we resume from a previous training.
  if opt.train_from:
    logger.info('Loading checkpoint from %s' % opt.train_from)
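    # load all tensors onto the CPU (regardless of the device they were saved from)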
    checkpoint = torch.load(opt.train_from,
                            map_location=lambda storage, loc: storage)

    # Load default opt values, then overwrite them with the opts stored in
    # the checkpoint. This is useful for retraining a model after adding a
    # new option (one not present in the checkpoint).
    dummy_parser = configargparse.ArgumentParser()
    opts.model_opts(dummy_parser)
    default_opt = dummy_parser.parse_known_args([])[0]

    model_opt = default_opt
    model_opt.__dict__.update(checkpoint['opt'].__dict__)
  else:
    checkpoint = None
    model_opt = opt

  # Load fields generated from preprocess phase.
  fields = load_fields(opt, checkpoint)

  # Build model.
  model = build_model(model_opt, opt, fields, checkpoint)
  n_params, enc, dec = _tally_parameters(model)
  logger.info('encoder: %d' % enc)
  logger.info('decoder: %d' % dec)
  logger.info('* number of parameters: %d' % n_params)
  _check_save_model_path(opt)

  # Build optimizer.
  optim = build_optim(model, opt, checkpoint)

  # Build model saver
  model_saver = build_model_saver(model_opt, opt, model, fields, optim)

  trainer = build_trainer(opt, device_id, model, fields,
                          optim, model_saver=model_saver)

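  # iterator factories: each call builds a fresh dataset iterator, so the
  # trainer can re-create the train/valid streams whenever it needs them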
  def train_iter_fct(): 
    return build_dataset_iter(
      load_dataset("train", opt), fields, opt)

  def valid_iter_fct(): 
    return build_dataset_iter(
      load_dataset("valid", opt), fields, opt, is_train=False)

  # Do training.
  if len(opt.gpu_ranks):
    logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
  else:
    logger.info('Starting training on CPU, could be very slow')
  trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                opt.valid_steps)

  if opt.tensorboard:
    trainer.report_manager.tensorboard_writer.close()
Example #2
def main(args):
    #### set up cfg ####
    # default cfg
    cfg = get_cfg()

    # add registered cfg
    cfg = build_config(cfg, args.config_name)
    cfg.setup(args)

    #### seed ####
    SEED = cfg.seed
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
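    # deterministic cuDNN kernels (and no autotuner benchmarking) trade speed for run-to-run reproducibility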

    #### start searching ####
    trainer = build_trainer(cfg)
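    # run training inside try/finally so the resolved config is saved even if the run is interrupted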
    try:
        trainer.train(cfg.trainer.validate_always)
        if not cfg.trainer.validate_always:
            trainer.test()
    except KeyboardInterrupt:
        print('Caught KeyboardInterrupt, stopping early ...')
    finally:
        trainer.save_cfg()
Example #3
def train(config: str) -> None:
    yaml_file = yaml.load(open(config).read(), Loader=yaml.FullLoader)
    # Build Trainer
    train_configs = TrainerConfig(yaml_file)
    seed_everything(train_configs.seed)
    trainer = build_trainer(train_configs.namespace())

    # Build Model
    model_config = PersonaGPT2.ModelConfig(yaml_file)
    model = PersonaGPT2(model_config.namespace())

    ## Pruning
    if model_config.prune_mask:
        click.secho("Applying pruning mask..", fg="yellow")
        ## Load mask dict
        mask = torch.load(model_config.prune_mask)
        ## Apply masks directly to weights and permanently prune them
        for key in mask.keys():
            attrs = key.split(".")
            last_attr = attrs[-1][:-5]  # strip the trailing "_mask" suffix
            obj = model.gpt2
            for attr in attrs[:-1]:
                if attr.isdigit():
                    obj = obj[int(attr)]
                else:
                    obj = getattr(obj, attr)
            tensor_mask = mask[key].to(torch.device("cpu"))
            nn.utils.prune.custom_from_mask(obj,
                                            name=last_attr,
                                            mask=tensor_mask)
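            ## prune.remove drops the reparametrization and bakes the mask into
            ## the weight, making the pruning permanent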
            nn.utils.prune.remove(obj, last_attr)
        click.secho("Model pruned.", fg="yellow")

    data = DataModule(model.hparams, model.tokenizer)
    trainer.fit(model, data)

    ## Quantization
    if train_configs.quantize:
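        ## note: this only serializes model.state_dict(); any 8-bit quantization is
        ## assumed to have been applied to `model` before this point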
        click.secho("Saving quantized model..", fg="yellow")
        torch.save(
            model.state_dict(),
            os.path.join("experiments/", trainer.logger.version,
                         "8bit_model.pt"))
Example #4
s = sampler.Sampler(wf, h, mag0=m)
s.run(nruns, init_state=state)

state = s.state
wf.init_lt(state)


def gamma_fun(p):
    # return .05
    return max(.05 * (.994 ** p),
               .005)  # chosen to give a factor of 10 in about 400 steps
    # return .05 / (2 ** (p // 50))
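    # sanity check: .994 ** 400 ≈ 0.09, so the rate decays by roughly a factor
    # of 10 over ~400 steps before bottoming out at the .005 floor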


t = trainer.build_trainer(wf, h)

wf, elist = t.train(wf,
                    state,
                    nruns,
                    nsteps,
                    gamma_fun,
                    file='../Outputs/' + str(nspins) + '_alpha=' + str(alpha) +
                    '_Ising10_ti_',
                    out_freq=25)

#h = ising1d.Ising1d(40,1)
s = sampler.Sampler(wf, h)
s.run(nruns)

end = time.time()
Example #5
        attention_probs_dropout_prob=args.attention_probs_dropout_prob,
        device=args.device)
    optimizer, scheduler = build_optimizer_scheduler(
        model=model,
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        eps=args.eps,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        layer_decrease=args.layer_decrease,
        freeze_upto=args.freeze_upto,
        train_step=args.train_step)
    trainer = build_trainer(model=model,
                            data=olid_data,
                            optimizer=optimizer,
                            scheduler=scheduler,
                            max_grad_norm=args.max_grad_norm,
                            patience=args.patience,
                            record_every=args.record_every,
                            exp_name=exp_name)

    logger.info(f'Training logs are in {exp_name}')
    trained_model, summary = trainer.train(args.train_step)

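    # persist the run artifacts: best weights, tokenizer, test-set predictions, and the args used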
    best_model_file = os.path.join(trainer.exp_dir, 'best_model.pt')
    pred_file = os.path.join(trainer.exp_dir, 'prediction.tsv')
    summary_file = os.path.join(trainer.exp_dir, 'summary.txt')
    args_file = os.path.join(trainer.exp_dir, 'args.bin')
    save_model(trained_model, best_model_file)
    save_tokenizer(tokenizer, trainer.exp_dir)
    write_pred_to_file(trained_model, trainer.test_iter, tokenizer, pred_file)
    write_args_to_file(args, args_file)
Example #6
def main(arguments):
    parser = argparse.ArgumentParser(description='')

    parser.add_argument(
        '--cuda',
        help='-1 if no CUDA, else gpu id (single gpu is enough)',
        type=int,
        default=0)
    parser.add_argument('--random_seed',
                        help='random seed to use',
                        type=int,
                        default=111)

    # Paths and logging
    parser.add_argument('--log_file',
                        help='file to log to',
                        type=str,
                        default='training.log')
    parser.add_argument('--store_root',
                        help='store root path',
                        type=str,
                        default='checkpoint')
    parser.add_argument('--store_name',
                        help='store name prefix for current experiment',
                        type=str,
                        default='sts')
    parser.add_argument('--suffix',
                        help='store name suffix for current experiment',
                        type=str,
                        default='')
    parser.add_argument('--word_embs_file',
                        help='file containing word embs',
                        type=str,
                        default='glove/glove.840B.300d.txt')

    # Training resuming flag
    parser.add_argument('--resume',
                        help='whether to resume training',
                        action='store_true',
                        default=False)

    # Tasks
    parser.add_argument('--task',
                        help='training and evaluation task',
                        type=str,
                        default='sts-b')

    # Preprocessing options
    parser.add_argument('--max_seq_len',
                        help='max sequence length',
                        type=int,
                        default=40)
    parser.add_argument('--max_word_v_size',
                        help='max word vocab size',
                        type=int,
                        default=30000)

    # Embedding options
    parser.add_argument('--dropout_embs',
                        help='dropout rate for embeddings',
                        type=float,
                        default=.2)
    parser.add_argument('--d_word',
                        help='dimension of word embeddings',
                        type=int,
                        default=300)
    parser.add_argument('--glove',
                        help='1 if use glove, else from scratch',
                        type=int,
                        default=1)
    parser.add_argument('--train_words',
                        help='1 if make word embs trainable',
                        type=int,
                        default=0)

    # Model options
    parser.add_argument('--d_hid',
                        help='hidden dimension size',
                        type=int,
                        default=1500)
    parser.add_argument('--n_layers_enc',
                        help='number of RNN layers',
                        type=int,
                        default=2)
    parser.add_argument('--n_layers_highway',
                        help='number of highway layers',
                        type=int,
                        default=0)
    parser.add_argument('--dropout',
                        help='dropout rate to use in training',
                        type=float,
                        default=0.2)

    # Training options
    parser.add_argument('--batch_size',
                        help='batch size',
                        type=int,
                        default=128)
    parser.add_argument('--optimizer',
                        help='optimizer to use',
                        type=str,
                        default='adam')
    parser.add_argument('--lr',
                        help='starting learning rate',
                        type=float,
                        default=1e-4)
    parser.add_argument(
        '--loss',
        type=str,
        default='mse',
        choices=['mse', 'l1', 'focal_l1', 'focal_mse', 'huber'])
    parser.add_argument('--huber_beta',
                        type=float,
                        default=0.3,
                        help='beta for huber loss')
    parser.add_argument('--max_grad_norm',
                        help='max grad norm',
                        type=float,
                        default=5.)
    parser.add_argument('--val_interval',
                        help='number of iterations between validation checks',
                        type=int,
                        default=400)
    parser.add_argument('--max_vals',
                        help='maximum number of validation checks',
                        type=int,
                        default=100)
    parser.add_argument('--patience',
                        help='patience for early stopping',
                        type=int,
                        default=10)

    # Imbalanced-regression options
    # LDS (label distribution smoothing)
    parser.add_argument('--lds',
                        action='store_true',
                        default=False,
                        help='whether to enable LDS')
    parser.add_argument('--lds_kernel',
                        type=str,
                        default='gaussian',
                        choices=['gaussian', 'triang', 'laplace'],
                        help='LDS kernel type')
    parser.add_argument('--lds_ks',
                        type=int,
                        default=5,
                        help='LDS kernel size: should be odd number')
    parser.add_argument('--lds_sigma',
                        type=float,
                        default=2,
                        help='LDS gaussian/laplace kernel sigma')
    # FDS (feature distribution smoothing)
    parser.add_argument('--fds',
                        action='store_true',
                        default=False,
                        help='whether to enable FDS')
    parser.add_argument('--fds_kernel',
                        type=str,
                        default='gaussian',
                        choices=['gaussian', 'triang', 'laplace'],
                        help='FDS kernel type')
    parser.add_argument('--fds_ks',
                        type=int,
                        default=5,
                        help='FDS kernel size: should be odd number')
    parser.add_argument('--fds_sigma',
                        type=float,
                        default=2,
                        help='FDS gaussian/laplace kernel sigma')
    parser.add_argument('--start_update',
                        type=int,
                        default=0,
                        help='which epoch to start FDS updating')
    parser.add_argument(
        '--start_smooth',
        type=int,
        default=1,
        help='which epoch to start using FDS to smooth features')
    parser.add_argument('--bucket_num',
                        type=int,
                        default=50,
                        help='maximum bucket considered for FDS')
    parser.add_argument('--bucket_start',
                        type=int,
                        default=0,
                        help='minimum(starting) bucket for FDS')
    parser.add_argument('--fds_mmt',
                        type=float,
                        default=0.9,
                        help='FDS momentum')

    # re-weighting: SQRT_INV / INV
    parser.add_argument('--reweight',
                        type=str,
                        default='none',
                        choices=['none', 'sqrt_inv', 'inverse'],
                        help='cost-sensitive reweighting scheme')
    # two-stage training: RRT
    parser.add_argument(
        '--retrain_fc',
        action='store_true',
        default=False,
        help='whether to retrain last regression layer (regressor)')
    parser.add_argument(
        '--pretrained',
        type=str,
        default='',
        help='pretrained checkpoint file path to load backbone weights for RRT'
    )
    # evaluate only
    parser.add_argument('--evaluate',
                        action='store_true',
                        default=False,
                        help='evaluate only flag')
    parser.add_argument('--eval_model',
                        type=str,
                        default='',
                        help='the model to evaluate on; if not specified, '
                        'use the default best model in store_dir')

    args = parser.parse_args(arguments)

    os.makedirs(args.store_root, exist_ok=True)

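    # encode the key hyperparameters (reweighting, LDS/FDS settings, loss, seed, ...) into the experiment folder name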
    if not args.lds and args.reweight != 'none':
        args.store_name += f'_{args.reweight}'
    if args.lds:
        args.store_name += f'_lds_{args.lds_kernel[:3]}_{args.lds_ks}'
        if args.lds_kernel in ['gaussian', 'laplace']:
            args.store_name += f'_{args.lds_sigma}'
    if args.fds:
        args.store_name += f'_fds_{args.fds_kernel[:3]}_{args.fds_ks}'
        if args.fds_kernel in ['gaussian', 'laplace']:
            args.store_name += f'_{args.fds_sigma}'
        args.store_name += f'_{args.start_update}_{args.start_smooth}_{args.fds_mmt}'
    if args.retrain_fc:
        args.store_name += f'_retrain_fc'

    if args.loss == 'huber':
        args.store_name += f'_{args.loss}_beta_{args.huber_beta}'
    else:
        args.store_name += f'_{args.loss}'

    args.store_name += f'_seed_{args.random_seed}_valint_{args.val_interval}_patience_{args.patience}' \
                       f'_{args.optimizer}_{args.lr}_{args.batch_size}'
    args.store_name += f'_{args.suffix}' if len(args.suffix) else ''

    args.store_dir = os.path.join(args.store_root, args.store_name)

    if not args.evaluate and not args.resume:
        if os.path.exists(args.store_dir):
            if query_yes_no('overwrite previous folder: {} ?'.format(
                    args.store_dir)):
                shutil.rmtree(args.store_dir)
                print(args.store_dir + ' removed.\n')
            else:
                raise RuntimeError('Output folder {} already exists'.format(
                    args.store_dir))
        logging.info(f"===> Creating folder: {args.store_dir}")
        os.makedirs(args.store_dir)

    # Logistics
    logging.root.handlers = []
    if os.path.exists(args.store_dir):
        log_file = os.path.join(args.store_dir, args.log_file)
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s | %(message)s",
            handlers=[logging.FileHandler(log_file),
                      logging.StreamHandler()])
    else:
        logging.basicConfig(level=logging.INFO,
                            format="%(asctime)s | %(message)s",
                            handlers=[logging.StreamHandler()])
    logging.info(args)

    seed = random.randint(1,
                          10000) if args.random_seed < 0 else args.random_seed
    random.seed(seed)
    torch.manual_seed(seed)
    if args.cuda >= 0:
        logging.info("Using GPU %d", args.cuda)
        torch.cuda.set_device(args.cuda)
        torch.cuda.manual_seed_all(seed)
    logging.info("Using random seed %d", seed)

    # Load tasks
    logging.info("Loading tasks...")
    start_time = time.time()
    tasks, vocab, word_embs = build_tasks(args)
    logging.info('\tFinished loading tasks in %.3fs', time.time() - start_time)

    # Build model
    logging.info('Building model...')
    start_time = time.time()
    model = build_model(args, vocab, word_embs, tasks)
    logging.info('\tFinished building model in %.3fs',
                 time.time() - start_time)

    # Set up trainer
    iterator = BasicIterator(args.batch_size)
    trainer, train_params, opt_params = build_trainer(args, model, iterator)

    # Train
    if tasks and not args.evaluate:
        if args.retrain_fc and len(args.pretrained):
            model_path = args.pretrained
            assert os.path.isfile(
                model_path), f"No checkpoint found at '{model_path}'"
            model_state = torch.load(model_path,
                                     map_location=device_mapping(args.cuda))
            trainer._model = resume_checkpoint(trainer._model,
                                               model_state,
                                               backbone_only=True)
            logging.info(f'Pre-trained backbone weights loaded: {model_path}')
            logging.info('Retrain last regression layer only!')
            for name, param in trainer._model.named_parameters():
                if "sts-b_pred_layer" not in name:
                    param.requires_grad = False
            logging.info(
                f'Only optimize parameters: {[n for n, p in trainer._model.named_parameters() if p.requires_grad]}'
            )
            to_train = [(n, p) for n, p in trainer._model.named_parameters()
                        if p.requires_grad]
        else:
            to_train = [(n, p) for n, p in model.named_parameters()
                        if p.requires_grad]

        trainer.train(tasks, args.val_interval, to_train, opt_params,
                      args.resume)
    else:
        logging.info("Skipping training...")

    logging.info('Testing on test set...')
    model_path = os.path.join(
        args.store_dir,
        "model_state_best.th") if not len(args.eval_model) else args.eval_model
    assert os.path.isfile(model_path), f"No checkpoint found at '{model_path}'"
    logging.info(f'Evaluating {model_path}...')
    model_state = torch.load(model_path,
                             map_location=device_mapping(args.cuda))
    model = resume_checkpoint(model, model_state)
    te_preds, te_labels, _ = evaluate(model,
                                      tasks,
                                      iterator,
                                      cuda_device=args.cuda,
                                      split="test")
    if not len(args.eval_model):
        np.savez_compressed(os.path.join(args.store_dir,
                                         f"{args.store_name}.npz"),
                            preds=te_preds,
                            labels=te_labels)

    logging.info("Done testing.")
Example #7
def main(opt, device_id):
  # device_id = -1
  # initialize the GPU
  opt = training_opt_postprocessing(opt, device_id)
  init_logger(opt.log_file)
  # Load checkpoint if we resume from a previous training.
  if opt.train_from:
    logger.info('Loading checkpoint from %s' % opt.train_from)
    # Load all tensors onto the CPU
    checkpoint = torch.load(opt.train_from,
                            map_location=lambda storage, loc: storage)

    # Load default opt values, then overwrite them with the opts stored in
    # the checkpoint. This is useful for retraining a model after adding a
    # new option (one not present in the checkpoint).
    dummy_parser = configargparse.ArgumentParser()
    opts.model_opts(dummy_parser)
    # parse_known_args returns two values; the first has the same type as the return value of parse_args()
    default_opt = dummy_parser.parse_known_args([])[0]
    model_opt = default_opt
    # add the options stored in the checkpoint's opt on top of the defaults;
    # that is, options can only be added, never removed or modified.
    # If so, is opt even needed later on?
    model_opt.__dict__.update(checkpoint['opt'].__dict__)
  else:
    # first run, no checkpoint to load
    checkpoint = None
    model_opt = opt

  # Load fields generated from preprocess phase.
  # {"src": Field, "tgt": Field, "indices": Field}
  # The most important attribute of a Field is vocab, which contains freqs, itos and stoi.
  # freqs holds word frequencies, excluding special tokens.
  # src: stoi contains <unk> and <blank>, but not <s> or </s>
  # tgt: stoi contains <unk>, <blank>, <s> and </s>
  # <unk> = 0, <blank> (pad) = 1
  fields = load_fields(opt, checkpoint)

  # Build model.
  # on the first run the opt argument shouldn't be needed; model_opt could be used instead
  model = build_model(model_opt, opt, fields, checkpoint)
  # for name, param in model.named_parameters():
  #   if param.requires_grad:
  #       print(name)
  n_params, enc, dec = _tally_parameters(model)
  logger.info('encoder: %d' % enc)
  logger.info('decoder: %d' % dec)
  logger.info('* number of parameters: %d' % n_params)
  # create the model save directory if it does not exist
  _check_save_model_path(opt)

  # Build optimizer.
  optim = build_optim(model, opt, checkpoint)

  # Build model saver
  model_saver = build_model_saver(model_opt, opt, model, fields, optim)

  trainer = build_trainer(opt, device_id, model, fields,
                          optim, model_saver=model_saver)
  # print all model parameters
  # for name, param in model.named_parameters():
  #   if param.requires_grad:
  #       print(param)
      
  def train_iter_fct(): 
    return build_dataset_iter(
      load_dataset("train", opt), fields, opt)

  def valid_iter_fct(): 
    return build_dataset_iter(
      load_dataset("valid", opt), fields, opt, is_train=False)

  # Do training.
  if len(opt.gpu_ranks):
    logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
  else:
    logger.info('Starting training on CPU, could be very slow')
  trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps,
                opt.valid_steps)

  if opt.tensorboard:
    trainer.report_manager.tensorboard_writer.close()
Example #8
def main(arguments):
    ''' Train or load a model. Evaluate on some tasks. '''
    parser = argparse.ArgumentParser(description='')

    # Logistics
    parser.add_argument('--cuda', help='-1 if no CUDA, else gpu id', type=int, default=0)
    parser.add_argument('--random_seed', help='random seed to use', type=int, default=19)

    # Paths and logging
    parser.add_argument('--log_file', help='file to log to', type=str, default='log.log')
    parser.add_argument('--exp_dir', help='directory containing shared preprocessing', type=str)
    parser.add_argument('--run_dir', help='directory for saving results, models, etc.', type=str)
    parser.add_argument('--word_embs_file', help='file containing word embs', type=str, default='')
    parser.add_argument('--preproc_file', help='file containing saved preprocessing stuff',
                        type=str, default='preproc.pkl')

    # Time saving flags
    parser.add_argument('--should_train', help='1 if should train model', type=int, default=1)
    parser.add_argument('--load_model', help='1 if load from checkpoint', type=int, default=1)
    parser.add_argument('--load_epoch', help='Force loading from a certain epoch', type=int,
                        default=-1)
    parser.add_argument('--load_tasks', help='1 if load tasks', type=int, default=1)
    parser.add_argument('--load_preproc', help='1 if load vocabulary', type=int, default=1)

    # Tasks and task-specific classifiers
    parser.add_argument('--train_tasks', help='comma separated list of tasks, or "all" or "none"',
                        type=str)
    parser.add_argument('--eval_tasks', help='list of additional tasks to train a classifier, ' +
                        'then evaluate on', type=str, default='')
    parser.add_argument('--classifier', help='type of classifier to use', type=str,
                        default='log_reg', choices=['log_reg', 'mlp', 'fancy_mlp'])
    parser.add_argument('--classifier_hid_dim', help='hid dim of classifier', type=int, default=512)
    parser.add_argument('--classifier_dropout', help='classifier dropout', type=float, default=0.0)

    # Preprocessing options
    parser.add_argument('--max_seq_len', help='max sequence length', type=int, default=40)
    parser.add_argument('--max_word_v_size', help='max word vocab size', type=int, default=30000)

    # Embedding options
    parser.add_argument('--dropout_embs', help='dropout rate for embeddings', type=float, default=.2)
    parser.add_argument('--d_word', help='dimension of word embeddings', type=int, default=300)
    parser.add_argument('--glove', help='1 if use glove, else from scratch', type=int, default=1)
    parser.add_argument('--train_words', help='1 if make word embs trainable', type=int, default=0)
    parser.add_argument('--elmo', help='1 if use elmo', type=int, default=0)
    parser.add_argument('--deep_elmo', help='1 if use elmo post LSTM', type=int, default=0)
    parser.add_argument('--elmo_no_glove', help='1 if no glove, assuming elmo', type=int, default=0)
    parser.add_argument('--cove', help='1 if use cove', type=int, default=0)

    # Model options
    parser.add_argument('--pair_enc', help='type of pair encoder to use', type=str, default='simple',
                        choices=['simple', 'attn'])
    parser.add_argument('--d_hid', help='hidden dimension size', type=int, default=4096)
    parser.add_argument('--n_layers_enc', help='number of RNN layers', type=int, default=1)
    parser.add_argument('--n_layers_highway', help='num of highway layers', type=int, default=1)
    parser.add_argument('--dropout', help='dropout rate to use in training', type=float, default=.2)

    # Training options
    parser.add_argument('--no_tqdm', help='1 to turn off tqdm', type=int, default=0)
    parser.add_argument('--trainer_type', help='type of trainer', type=str,
                        choices=['sampling', 'mtl'], default='sampling')
    parser.add_argument('--shared_optimizer', help='1 to use same optimizer for all tasks',
                        type=int, default=1)
    parser.add_argument('--batch_size', help='batch size', type=int, default=64)
    parser.add_argument('--optimizer', help='optimizer to use', type=str, default='sgd')
    parser.add_argument('--n_epochs', help='n epochs to train for', type=int, default=10)
    parser.add_argument('--lr', help='starting learning rate', type=float, default=1.0)
    parser.add_argument('--min_lr', help='minimum learning rate', type=float, default=1e-5)
    parser.add_argument('--max_grad_norm', help='max grad norm', type=float, default=5.)
    parser.add_argument('--weight_decay', help='weight decay value', type=float, default=0.0)
    parser.add_argument('--task_patience', help='patience in decaying per task lr',
                        type=int, default=0)
    parser.add_argument('--scheduler_threshold', help='scheduler threshold',
                        type=float, default=0.0)
    parser.add_argument('--lr_decay_factor', help='lr decay factor when val score doesn\'t improve',
                        type=float, default=.5)

    # Multi-task training options
    parser.add_argument('--val_interval', help='Number of passes between validation checks',
                        type=int, default=10)
    parser.add_argument('--max_vals', help='Maximum number of validation checks', type=int,
                        default=100)
    parser.add_argument('--bpp_method', help='if using nonsampling trainer, ' +
                        'method for calculating number of batches per pass', type=str,
                        choices=['fixed', 'percent_tr', 'proportional_rank'], default='fixed')
    parser.add_argument('--bpp_base', help='If sampling or fixed bpp ' +
                        'per pass, this is the bpp. If proportional, this ' +
                        'is the smallest number', type=int, default=10)
    parser.add_argument('--weighting_method', help='Weighting method for sampling', type=str,
                        choices=['uniform', 'proportional'], default='uniform')
    parser.add_argument('--scaling_method', help='method for scaling loss', type=str,
                        choices=['min', 'max', 'unit', 'none'], default='none')
    parser.add_argument('--patience', help='patience in early stopping', type=int, default=5)
    parser.add_argument('--task_ordering', help='Method for ordering tasks', type=str, default='given',
                        choices=['given', 'random', 'random_per_pass', 'small_to_large', 'large_to_small'])

    args = parser.parse_args(arguments)

    # Logistics #
    log.basicConfig(format='%(asctime)s: %(message)s', level=log.INFO, datefmt='%m/%d %I:%M:%S %p')
    log_file = os.path.join(args.run_dir, args.log_file)
    file_handler = log.FileHandler(log_file)
    log.getLogger().addHandler(file_handler)
    log.info(args)
    seed = random.randint(1, 10000) if args.random_seed < 0 else args.random_seed
    random.seed(seed)
    torch.manual_seed(seed)
    if args.cuda >= 0:
        log.info("Using GPU %d", args.cuda)
        torch.cuda.set_device(args.cuda)
        torch.cuda.manual_seed_all(seed)
    log.info("Using random seed %d", seed)

    # Load tasks #
    log.info("Loading tasks...")
    start_time = time.time()
    train_tasks, eval_tasks, vocab, word_embs = build_tasks(args)
    tasks = train_tasks + eval_tasks
    log.info('\tFinished loading tasks in %.3fs', time.time() - start_time)

    # Build model #
    log.info('Building model...')
    start_time = time.time()
    model = build_model(args, vocab, word_embs, tasks)
    log.info('\tFinished building model in %.3fs', time.time() - start_time)

    # Set up trainer #
    # TODO(Alex): move iterator creation
    iterator = BasicIterator(args.batch_size)
    #iterator = BucketIterator(sorting_keys=[("sentence1", "num_tokens")], batch_size=args.batch_size)
    trainer, train_params, opt_params, schd_params = build_trainer(args, args.trainer_type, model, iterator)

    # Train #
    if train_tasks and args.should_train:
        #to_train = [p for p in model.parameters() if p.requires_grad]
        to_train = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        if args.trainer_type == 'mtl':
            best_epochs = trainer.train(train_tasks, args.task_ordering, args.val_interval,
                                        args.max_vals, args.bpp_method, args.bpp_base, to_train,
                                        opt_params, schd_params, args.load_model)
        elif args.trainer_type == 'sampling':
            if args.weighting_method == 'uniform':
                log.info("Sampling tasks uniformly")
            elif args.weighting_method == 'proportional':
                log.info("Sampling tasks proportional to number of training batches")

            if args.scaling_method == 'max':
                # divide by # batches, multiply by max # batches
                log.info("Scaling losses to largest task")
            elif args.scaling_method == 'min':
                # divide by # batches, multiply by fewest # batches
                log.info("Scaling losses to the smallest task")
            elif args.scaling_method == 'unit':
                log.info("Dividing losses by number of training batches")
            best_epochs = trainer.train(train_tasks, args.val_interval, args.bpp_base,
                                        args.weighting_method, args.scaling_method, to_train,
                                        opt_params, schd_params, args.shared_optimizer,
                                        args.load_model)
    else:
        log.info("Skipping training.")
        best_epochs = {}

    # train just the classifiers for eval tasks
    for task in eval_tasks:
        pred_layer = getattr(model, "%s_pred_layer" % task.name)
        to_train = pred_layer.parameters()
        trainer = MultiTaskTrainer.from_params(model, args.run_dir + '/%s/' % task.name,
                                               iterator, copy.deepcopy(train_params))
        trainer.train([task], args.task_ordering, 1, args.max_vals, 'percent_tr', 1, to_train,
                      opt_params, schd_params, 1)
        layer_path = os.path.join(args.run_dir, task.name, "%s_best.th" % task.name)
        layer_state = torch.load(layer_path, map_location=device_mapping(args.cuda))
        model.load_state_dict(layer_state)

    # Evaluate: load the different task best models and evaluate them
    # TODO(Alex): put this in evaluate file
    all_results = {}

    if not best_epochs and args.load_epoch >= 0:
        epoch_to_load = args.load_epoch
    elif not best_epochs and not args.load_epoch:
        serialization_files = os.listdir(args.run_dir)
        model_checkpoints = [x for x in serialization_files if "model_state_epoch" in x]
        epoch_to_load = max([int(x.split("model_state_epoch_")[-1].strip(".th")) \
                             for x in model_checkpoints])
    else:
        epoch_to_load = -1

    #for task in [task.name for task in train_tasks] + ['micro', 'macro']:
    for task in ['macro']:
        log.info("Testing on %s..." % task)

        # Load best model
        load_idx = best_epochs[task] if best_epochs else epoch_to_load
        model_path = os.path.join(args.run_dir, "model_state_epoch_{}.th".format(load_idx))
        model_state = torch.load(model_path, map_location=device_mapping(args.cuda))
        model.load_state_dict(model_state)

        # Test evaluation and prediction
        # could just filter out tasks to get what i want...
        #tasks = [task for task in tasks if 'mnli' in task.name]
        te_results, te_preds = evaluate(model, tasks, iterator, cuda_device=args.cuda, split="test")
        val_results, _ = evaluate(model, tasks, iterator, cuda_device=args.cuda, split="val")

        if task == 'macro':
            all_results[task] = (val_results, te_results, model_path)
            for eval_task, task_preds in te_preds.items(): # write predictions for each task
                #if 'mnli' not in eval_task:
                #    continue
                idxs_and_preds = [(idx, pred) for pred, idx in zip(task_preds[0], task_preds[1])]
                idxs_and_preds.sort(key=lambda x: x[0])
                if 'mnli' in eval_task:
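                    # the prediction list is ordered: the first 9796 rows are MNLI-matched,
                    # the next 9847 MNLI-mismatched, and the remainder the diagnostic set;
                    # write one GLUE submission file per slice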
                    pred_map = {0: 'neutral', 1: 'entailment', 2: 'contradiction'}
                    with open(os.path.join(args.run_dir, "%s-m.tsv" % (eval_task)), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        split_idx = 0
                        for idx, pred in idxs_and_preds[:9796]:
                            pred = pred_map[pred]
                            pred_fh.write("%d\t%s\n" % (split_idx, pred))
                            split_idx += 1
                    with open(os.path.join(args.run_dir, "%s-mm.tsv" % (eval_task)), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        split_idx = 0
                        for idx, pred in idxs_and_preds[9796:9796+9847]:
                            pred = pred_map[pred]
                            pred_fh.write("%d\t%s\n" % (split_idx, pred))
                            split_idx += 1
                    with open(os.path.join(args.run_dir, "diagnostic.tsv"), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        split_idx = 0
                        for idx, pred in idxs_and_preds[9796+9847:]:
                            pred = pred_map[pred]
                            pred_fh.write("%d\t%s\n" % (split_idx, pred))
                            split_idx += 1
                else:
                    with open(os.path.join(args.run_dir, "%s.tsv" % (eval_task)), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        for idx, pred in idxs_and_preds:
                            if 'sts-b' in eval_task:
                                pred_fh.write("%d\t%.3f\n" % (idx, pred))
                            elif 'rte' in eval_task:
                                pred = 'entailment' if pred else 'not_entailment'
                                pred_fh.write('%d\t%s\n' % (idx, pred))
                            elif 'squad' in eval_task:
                                pred = 'entailment' if pred else 'not_entailment'
                                pred_fh.write('%d\t%s\n' % (idx, pred))
                            else:
                                pred_fh.write("%d\t%d\n" % (idx, pred))

            with open(os.path.join(args.exp_dir, "results.tsv"), 'a') as results_fh: # aggregate results easily
                run_name = args.run_dir.split('/')[-1]
                all_metrics_str = ', '.join(['%s: %.3f' % (metric, score) for \
                                            metric, score in val_results.items()])
                results_fh.write("%s\t%s\n" % (run_name, all_metrics_str))
    log.info("Done testing")

    # Dump everything to a pickle for posterity
    pkl.dump(all_results, open(os.path.join(args.run_dir, "results.pkl"), 'wb'))