def _restore_checkpoint(self):
    """
    Restore model, per-task training, and metric states from the last
    checkpoint saved in ``self._serialization_dir``.

    Only intended for resuming training: it loads the model weights in
    place, restores each task's optimizer and early-stopping flag,
    fast-forwards each task's batch generator to where it left off within
    the current epoch, and restores per-metric histories.

    Returns
    -------
    tuple
        ``(epoch, iter, should_stop)`` from the saved training state.
    """
    if not self._serialization_dir:
        raise ConfigurationError("serialization_dir not specified - cannot "
                                 "restore a model without a directory path.")

    logging.info(f'Recovering last model / training / task / metric states '
                 f'from {self._serialization_dir}...')

    # All tensors are remapped onto this process's configured device.
    map_loc = device_mapping(self._cuda_device)

    model_state = torch.load(
        os.path.join(self._serialization_dir, "model_state.th"),
        map_location=map_loc)
    self._model.load_state_dict(model_state)

    saved_task_states = torch.load(
        os.path.join(self._serialization_dir, "task_state.th"),
        map_location=map_loc)
    for name, state in saved_task_states.items():
        info = self._task_infos[name]
        info['total_batches_trained'] = state['total_batches_trained']
        info['optimizer'].load_state_dict(state['optimizer'])
        info['stopped'] = state['stopped']
        # Skip the generator forward to its position within the current
        # epoch at the time the checkpoint was written.
        n_to_skip = state['total_batches_trained'] % info['n_tr_batches']
        for _ in itertools.islice(info['tr_generator'], n_to_skip):
            pass

    saved_metric_states = torch.load(
        os.path.join(self._serialization_dir, "metric_state.th"),
        map_location=map_loc)
    for name, state in saved_metric_states.items():
        info = self._metric_infos[name]
        info['hist'] = state['hist']
        info['stopped'] = state['stopped']
        info['best'] = state['best']

    training_state = torch.load(
        os.path.join(self._serialization_dir, "training_state.th"),
        map_location=map_loc)
    return (training_state["epoch"], training_state["iter"],
            training_state["should_stop"])
Exemplo n.º 2
0
    def _restore_checkpoint(self):
        """
        Restores a model from a serialization_dir to the last saved checkpoint.
        This includes an epoch count and optimizer state, which is serialized
        separately from model parameters. This function should only be used to
        continue training - if you wish to load a model for inference/load
        parts of a model into a new computation graph, you should use the
        native Pytorch functions:
        ``model.load_state_dict(torch.load("/path/to/model/weights.th"))``

        Returns
        -------
        (epoch, should_stop)
            The epoch at which to resume training and whether training had
            already been flagged to stop.
        """
        if not self._serialization_dir:
            raise ConfigurationError(
                "serialization_dir not specified - cannot "
                "restore a model without a directory path.")

        serialization_files = os.listdir(self._serialization_dir)
        model_checkpoints = [
            x for x in serialization_files if "model_state_epoch" in x
        ]
        # Parse the epoch out of "model_state_epoch_<N>.th".  split(".") is
        # used instead of str.strip(".th"): strip() treats its argument as a
        # character *set*, not a suffix, which is a latent bug for any file
        # name whose epoch digits border 't'/'h'/'.' characters.
        epoch_to_load = max(
            int(x.split("model_state_epoch_")[-1].split(".")[0])
            for x in model_checkpoints)

        model_path = os.path.join(
            self._serialization_dir,
            "model_state_epoch_{}.th".format(epoch_to_load))
        training_state_path = os.path.join(
            self._serialization_dir,
            "training_state_epoch_{}.th".format(epoch_to_load))
        task_state_path = os.path.join(
            self._serialization_dir,
            "task_state_epoch_{}.th".format(epoch_to_load))
        metric_state_path = os.path.join(
            self._serialization_dir,
            "metric_state_epoch_{}.th".format(epoch_to_load))

        # BUG FIX: previously only the model weights were loaded with
        # map_location, so task/metric/training states checkpointed on a
        # different device (e.g. a GPU tensor inside the optimizer state)
        # failed to resume or landed on the wrong device.  Remap everything.
        map_location = device_mapping(self._cuda_device)

        model_state = torch.load(model_path, map_location=map_location)
        self._model.load_state_dict(model_state)

        task_states = torch.load(task_state_path, map_location=map_location)
        for task_name, task_state in task_states.items():
            # 'global' holds the shared optimizer/scheduler, handled below.
            if task_name == 'global':
                continue
            self._task_infos[task_name]['total_batches_trained'] = task_state[
                'total_batches_trained']
            self._task_infos[task_name]['optimizer'].load_state_dict(
                task_state['optimizer'])
            # Schedulers were saved as a plain attribute dict, so they are
            # restored attribute-by-attribute rather than via load_state_dict.
            for param, val in task_state['scheduler'].items():
                setattr(self._task_infos[task_name]['scheduler'], param, val)
            self._task_infos[task_name]['stopped'] = task_state['stopped']
            generator = self._task_infos[task_name]['tr_generator']
            # Fast-forward the batch generator to its position within the
            # current epoch at checkpoint time.
            for _ in itertools.islice(
                    generator, task_state['total_batches_trained'] %
                    self._task_infos[task_name]['n_tr_batches']):
                pass
        if task_states['global']['optimizer'] is not None:
            self._g_optimizer.load_state_dict(
                task_states['global']['optimizer'])
        if task_states['global']['scheduler'] is not None:
            for param, val in task_states['global']['scheduler'].items():
                setattr(self._g_scheduler, param, val)

        metric_states = torch.load(metric_state_path,
                                   map_location=map_location)
        for metric_name, metric_state in metric_states.items():
            self._metric_infos[metric_name]['hist'] = metric_state['hist']
            self._metric_infos[metric_name]['stopped'] = metric_state[
                'stopped']
            self._metric_infos[metric_name]['best'] = metric_state['best']

        training_state = torch.load(training_state_path,
                                    map_location=map_location)
        return training_state["epoch"], training_state["should_stop"]
Exemplo n.º 3
0
def main(arguments):
    ''' Train or load a model. Evaluate on some tasks.

    Parses command-line flags, builds tasks/model/trainer, optionally trains
    (multi-task), then loads the best (or requested) checkpoint, evaluates on
    val/test splits, writes per-task prediction TSVs, and appends aggregated
    results to exp_dir/results.tsv.
    '''
    parser = argparse.ArgumentParser(description='')

    # Logistics
    parser.add_argument('--cuda', help='-1 if no CUDA, else gpu id', type=int, default=0)
    parser.add_argument('--random_seed', help='random seed to use', type=int, default=19)

    # Paths and logging
    parser.add_argument('--log_file', help='file to log to', type=str, default='log.log')
    parser.add_argument('--exp_dir', help='directory containing shared preprocessing', type=str)
    parser.add_argument('--run_dir', help='directory for saving results, models, etc.', type=str)
    parser.add_argument('--word_embs_file', help='file containing word embs', type=str, default='')
    parser.add_argument('--preproc_file', help='file containing saved preprocessing stuff',
                        type=str, default='preproc.pkl')

    # Time saving flags
    parser.add_argument('--should_train', help='1 if should train model', type=int, default=1)
    parser.add_argument('--load_model', help='1 if load from checkpoint', type=int, default=1)
    parser.add_argument('--load_epoch', help='Force loading from a certain epoch', type=int,
                        default=-1)
    parser.add_argument('--load_tasks', help='1 if load tasks', type=int, default=1)
    parser.add_argument('--load_preproc', help='1 if load vocabulary', type=int, default=1)

    # Tasks and task-specific classifiers
    parser.add_argument('--train_tasks', help='comma separated list of tasks, or "all" or "none"',
                        type=str)
    parser.add_argument('--eval_tasks', help='list of additional tasks to train a classifier,' +
                        'then evaluate on', type=str, default='')
    parser.add_argument('--classifier', help='type of classifier to use', type=str,
                        default='log_reg', choices=['log_reg', 'mlp', 'fancy_mlp'])
    parser.add_argument('--classifier_hid_dim', help='hid dim of classifier', type=int, default=512)
    parser.add_argument('--classifier_dropout', help='classifier dropout', type=float, default=0.0)

    # Preprocessing options
    parser.add_argument('--max_seq_len', help='max sequence length', type=int, default=40)
    parser.add_argument('--max_word_v_size', help='max word vocab size', type=int, default=30000)

    # Embedding options
    parser.add_argument('--dropout_embs', help='dropout rate for embeddings', type=float, default=.2)
    parser.add_argument('--d_word', help='dimension of word embeddings', type=int, default=300)
    parser.add_argument('--glove', help='1 if use glove, else from scratch', type=int, default=1)
    parser.add_argument('--train_words', help='1 if make word embs trainable', type=int, default=0)
    parser.add_argument('--elmo', help='1 if use elmo', type=int, default=0)
    parser.add_argument('--deep_elmo', help='1 if use elmo post LSTM', type=int, default=0)
    parser.add_argument('--elmo_no_glove', help='1 if no glove, assuming elmo', type=int, default=0)
    parser.add_argument('--cove', help='1 if use cove', type=int, default=0)

    # Model options
    parser.add_argument('--pair_enc', help='type of pair encoder to use', type=str, default='simple',
                        choices=['simple', 'attn'])
    parser.add_argument('--d_hid', help='hidden dimension size', type=int, default=4096)
    parser.add_argument('--n_layers_enc', help='number of RNN layers', type=int, default=1)
    parser.add_argument('--n_layers_highway', help='num of highway layers', type=int, default=1)
    parser.add_argument('--dropout', help='dropout rate to use in training', type=float, default=.2)

    # Training options
    parser.add_argument('--no_tqdm', help='1 to turn off tqdm', type=int, default=0)
    parser.add_argument('--trainer_type', help='type of trainer', type=str,
                        choices=['sampling', 'mtl'], default='sampling')
    parser.add_argument('--shared_optimizer', help='1 to use same optimizer for all tasks',
                        type=int, default=1)
    parser.add_argument('--batch_size', help='batch size', type=int, default=64)
    parser.add_argument('--optimizer', help='optimizer to use', type=str, default='sgd')
    parser.add_argument('--n_epochs', help='n epochs to train for', type=int, default=10)
    parser.add_argument('--lr', help='starting learning rate', type=float, default=1.0)
    parser.add_argument('--min_lr', help='minimum learning rate', type=float, default=1e-5)
    parser.add_argument('--max_grad_norm', help='max grad norm', type=float, default=5.)
    parser.add_argument('--weight_decay', help='weight decay value', type=float, default=0.0)
    parser.add_argument('--task_patience', help='patience in decaying per task lr',
                        type=int, default=0)
    parser.add_argument('--scheduler_threshold', help='scheduler threshold',
                        type=float, default=0.0)
    parser.add_argument('--lr_decay_factor', help='lr decay factor when val score doesn\'t improve',
                        type=float, default=.5)

    # Multi-task training options
    parser.add_argument('--val_interval', help='Number of passes between validation checks',
                        type=int, default=10)
    parser.add_argument('--max_vals', help='Maximum number of validation checks', type=int,
                        default=100)
    parser.add_argument('--bpp_method', help='if using nonsampling trainer, ' +
                        'method for calculating number of batches per pass', type=str,
                        choices=['fixed', 'percent_tr', 'proportional_rank'], default='fixed')
    parser.add_argument('--bpp_base', help='If sampling or fixed bpp' +
                        'per pass, this is the bpp. If proportional, this ' +
                        'is the smallest number', type=int, default=10)
    parser.add_argument('--weighting_method', help='Weighting method for sampling', type=str,
                        choices=['uniform', 'proportional'], default='uniform')
    parser.add_argument('--scaling_method', help='method for scaling loss', type=str,
                        choices=['min', 'max', 'unit', 'none'], default='none')
    parser.add_argument('--patience', help='patience in early stopping', type=int, default=5)
    parser.add_argument('--task_ordering', help='Method for ordering tasks', type=str, default='given',
                        choices=['given', 'random', 'random_per_pass', 'small_to_large', 'large_to_small'])

    args = parser.parse_args(arguments)

    # Logistics #
    log.basicConfig(format='%(asctime)s: %(message)s', level=log.INFO, datefmt='%m/%d %I:%M:%S %p')
    log_file = os.path.join(args.run_dir, args.log_file)
    file_handler = log.FileHandler(log_file)
    log.getLogger().addHandler(file_handler)
    log.info(args)
    # Negative random_seed means "pick one for me".
    seed = random.randint(1, 10000) if args.random_seed < 0 else args.random_seed
    random.seed(seed)
    torch.manual_seed(seed)
    if args.cuda >= 0:
        log.info("Using GPU %d", args.cuda)
        torch.cuda.set_device(args.cuda)
        torch.cuda.manual_seed_all(seed)
    log.info("Using random seed %d", seed)

    # Load tasks #
    log.info("Loading tasks...")
    start_time = time.time()
    train_tasks, eval_tasks, vocab, word_embs = build_tasks(args)
    tasks = train_tasks + eval_tasks
    log.info('\tFinished loading tasks in %.3fs', time.time() - start_time)

    # Build model #
    log.info('Building model...')
    start_time = time.time()
    model = build_model(args, vocab, word_embs, tasks)
    log.info('\tFinished building model in %.3fs', time.time() - start_time)

    # Set up trainer #
    # TODO(Alex): move iterator creation
    iterator = BasicIterator(args.batch_size)
    trainer, train_params, opt_params, schd_params = build_trainer(args, args.trainer_type, model, iterator)

    # Train #
    if train_tasks and args.should_train:
        to_train = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        if args.trainer_type == 'mtl':
            best_epochs = trainer.train(train_tasks, args.task_ordering, args.val_interval,
                                        args.max_vals, args.bpp_method, args.bpp_base, to_train,
                                        opt_params, schd_params, args.load_model)
        elif args.trainer_type == 'sampling':
            if args.weighting_method == 'uniform':
                log.info("Sampling tasks uniformly")
            elif args.weighting_method == 'proportional':
                log.info("Sampling tasks proportional to number of training batches")

            if args.scaling_method == 'max':
                # divide by # batches, multiply by max # batches
                log.info("Scaling losses to largest task")
            elif args.scaling_method == 'min':
                # divide by # batches, multiply by fewest # batches
                log.info("Scaling losses to the smallest task")
            elif args.scaling_method == 'unit':
                log.info("Dividing losses by number of training batches")
            best_epochs = trainer.train(train_tasks, args.val_interval, args.bpp_base,
                                        args.weighting_method, args.scaling_method, to_train,
                                        opt_params, schd_params, args.shared_optimizer,
                                        args.load_model)
    else:
        log.info("Skipping training.")
        best_epochs = {}

    # Train just the classifiers (prediction layers) for the eval-only tasks.
    for task in eval_tasks:
        pred_layer = getattr(model, "%s_pred_layer" % task.name)
        to_train = pred_layer.parameters()
        trainer = MultiTaskTrainer.from_params(model, args.run_dir + '/%s/' % task.name,
                                               iterator, copy.deepcopy(train_params))
        trainer.train([task], args.task_ordering, 1, args.max_vals, 'percent_tr', 1, to_train,
                      opt_params, schd_params, 1)
        layer_path = os.path.join(args.run_dir, task.name, "%s_best.th" % task.name)
        layer_state = torch.load(layer_path, map_location=device_mapping(args.cuda))
        model.load_state_dict(layer_state)

    # Evaluate: load the different task best models and evaluate them
    # TODO(Alex): put this in evaluate file
    all_results = {}

    # Decide which checkpoint epoch to evaluate when training was skipped.
    # BUG FIX: the original condition `not args.load_epoch` was only true for
    # load_epoch == 0, which the first branch already handled (0 >= 0), so the
    # directory scan was unreachable and epoch_to_load silently ended up as -1,
    # making the model load below look for "model_state_epoch_-1.th".
    if best_epochs:
        epoch_to_load = -1  # unused: best_epochs supplies the index per task
    elif args.load_epoch >= 0:
        epoch_to_load = args.load_epoch
    else:
        # No best epochs recorded and no epoch forced: fall back to the
        # newest checkpoint found in run_dir.
        serialization_files = os.listdir(args.run_dir)
        model_checkpoints = [x for x in serialization_files if "model_state_epoch" in x]
        # split(".") instead of strip(".th"): strip() takes a character set,
        # not a suffix, and can mangle the epoch number.
        epoch_to_load = max(int(x.split("model_state_epoch_")[-1].split(".")[0])
                            for x in model_checkpoints)

    for task in ['macro']:
        log.info("Testing on %s..." % task)

        # Load best model
        load_idx = best_epochs[task] if best_epochs else epoch_to_load
        model_path = os.path.join(args.run_dir, "model_state_epoch_{}.th".format(load_idx))
        model_state = torch.load(model_path, map_location=device_mapping(args.cuda))
        model.load_state_dict(model_state)

        # Test evaluation and prediction
        te_results, te_preds = evaluate(model, tasks, iterator, cuda_device=args.cuda, split="test")
        val_results, _ = evaluate(model, tasks, iterator, cuda_device=args.cuda, split="val")

        if task == 'macro':
            all_results[task] = (val_results, te_results, model_path)
            for eval_task, task_preds in te_preds.items():  # write predictions for each task
                idxs_and_preds = [(idx, pred) for pred, idx in zip(task_preds[0], task_preds[1])]
                idxs_and_preds.sort(key=lambda x: x[0])
                if 'mnli' in eval_task:
                    # NOTE(review): 9796 / 9847 look like MNLI matched /
                    # mismatched dev-set sizes; the third slice is the
                    # diagnostic set -- confirm against the data pipeline.
                    pred_map = {0: 'neutral', 1: 'entailment', 2: 'contradiction'}
                    with open(os.path.join(args.run_dir, "%s-m.tsv" % (eval_task)), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        split_idx = 0
                        for idx, pred in idxs_and_preds[:9796]:
                            pred = pred_map[pred]
                            pred_fh.write("%d\t%s\n" % (split_idx, pred))
                            split_idx += 1
                    with open(os.path.join(args.run_dir, "%s-mm.tsv" % (eval_task)), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        split_idx = 0
                        for idx, pred in idxs_and_preds[9796:9796+9847]:
                            pred = pred_map[pred]
                            pred_fh.write("%d\t%s\n" % (split_idx, pred))
                            split_idx += 1
                    with open(os.path.join(args.run_dir, "diagnostic.tsv"), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        split_idx = 0
                        for idx, pred in idxs_and_preds[9796+9847:]:
                            pred = pred_map[pred]
                            pred_fh.write("%d\t%s\n" % (split_idx, pred))
                            split_idx += 1
                else:
                    with open(os.path.join(args.run_dir, "%s.tsv" % (eval_task)), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        for idx, pred in idxs_and_preds:
                            if 'sts-b' in eval_task:
                                pred_fh.write("%d\t%.3f\n" % (idx, pred))
                            elif 'rte' in eval_task:
                                pred = 'entailment' if pred else 'not_entailment'
                                pred_fh.write('%d\t%s\n' % (idx, pred))
                            elif 'squad' in eval_task:
                                pred = 'entailment' if pred else 'not_entailment'
                                pred_fh.write('%d\t%s\n' % (idx, pred))
                            else:
                                pred_fh.write("%d\t%d\n" % (idx, pred))

            with open(os.path.join(args.exp_dir, "results.tsv"), 'a') as results_fh:  # aggregate results easily
                run_name = args.run_dir.split('/')[-1]
                all_metrics_str = ', '.join(['%s: %.3f' % (metric, score) for \
                                            metric, score in val_results.items()])
                results_fh.write("%s\t%s\n" % (run_name, all_metrics_str))
    log.info("Done testing")

    # Dump everything to a pickle for posterity
    pkl.dump(all_results, open(os.path.join(args.run_dir, "results.pkl"), 'wb'))
Exemplo n.º 4
0
def main(arguments):
    """Train and/or evaluate an STS-B regression model with optional
    LDS/FDS imbalanced-regression techniques.

    Parses flags, derives a run name encoding the hyperparameters, sets up
    logging and seeding, builds tasks/model/trainer, optionally trains
    (with optional backbone loading for two-stage RRT), then evaluates the
    best (or an explicitly given) checkpoint on the test split and saves
    predictions/labels to an .npz file.
    """
    parser = argparse.ArgumentParser(description='')

    parser.add_argument(
        '--cuda',
        help='-1 if no CUDA, else gpu id (single gpu is enough)',
        type=int,
        default=0)
    parser.add_argument('--random_seed',
                        help='random seed to use',
                        type=int,
                        default=111)

    # Paths and logging
    parser.add_argument('--log_file',
                        help='file to log to',
                        type=str,
                        default='training.log')
    parser.add_argument('--store_root',
                        help='store root path',
                        type=str,
                        default='checkpoint')
    parser.add_argument('--store_name',
                        help='store name prefix for current experiment',
                        type=str,
                        default='sts')
    parser.add_argument('--suffix',
                        help='store name suffix for current experiment',
                        type=str,
                        default='')
    parser.add_argument('--word_embs_file',
                        help='file containing word embs',
                        type=str,
                        default='glove/glove.840B.300d.txt')

    # Training resuming flag
    parser.add_argument('--resume',
                        help='whether to resume training',
                        action='store_true',
                        default=False)

    # Tasks
    parser.add_argument('--task',
                        help='training and evaluation task',
                        type=str,
                        default='sts-b')

    # Preprocessing options
    parser.add_argument('--max_seq_len',
                        help='max sequence length',
                        type=int,
                        default=40)
    parser.add_argument('--max_word_v_size',
                        help='max word vocab size',
                        type=int,
                        default=30000)

    # Embedding options
    parser.add_argument('--dropout_embs',
                        help='dropout rate for embeddings',
                        type=float,
                        default=.2)
    parser.add_argument('--d_word',
                        help='dimension of word embeddings',
                        type=int,
                        default=300)
    parser.add_argument('--glove',
                        help='1 if use glove, else from scratch',
                        type=int,
                        default=1)
    parser.add_argument('--train_words',
                        help='1 if make word embs trainable',
                        type=int,
                        default=0)

    # Model options
    parser.add_argument('--d_hid',
                        help='hidden dimension size',
                        type=int,
                        default=1500)
    parser.add_argument('--n_layers_enc',
                        help='number of RNN layers',
                        type=int,
                        default=2)
    parser.add_argument('--n_layers_highway',
                        help='number of highway layers',
                        type=int,
                        default=0)
    parser.add_argument('--dropout',
                        help='dropout rate to use in training',
                        type=float,
                        default=0.2)

    # Training options
    parser.add_argument('--batch_size',
                        help='batch size',
                        type=int,
                        default=128)
    parser.add_argument('--optimizer',
                        help='optimizer to use',
                        type=str,
                        default='adam')
    parser.add_argument('--lr',
                        help='starting learning rate',
                        type=float,
                        default=1e-4)
    parser.add_argument(
        '--loss',
        type=str,
        default='mse',
        choices=['mse', 'l1', 'focal_l1', 'focal_mse', 'huber'])
    parser.add_argument('--huber_beta',
                        type=float,
                        default=0.3,
                        help='beta for huber loss')
    parser.add_argument('--max_grad_norm',
                        help='max grad norm',
                        type=float,
                        default=5.)
    parser.add_argument('--val_interval',
                        help='number of iterations between validation checks',
                        type=int,
                        default=400)
    parser.add_argument('--max_vals',
                        help='maximum number of validation checks',
                        type=int,
                        default=100)
    parser.add_argument('--patience',
                        help='patience for early stopping',
                        type=int,
                        default=10)

    # imbalanced related
    # LDS (label distribution smoothing)
    parser.add_argument('--lds',
                        action='store_true',
                        default=False,
                        help='whether to enable LDS')
    parser.add_argument('--lds_kernel',
                        type=str,
                        default='gaussian',
                        choices=['gaussian', 'triang', 'laplace'],
                        help='LDS kernel type')
    parser.add_argument('--lds_ks',
                        type=int,
                        default=5,
                        help='LDS kernel size: should be odd number')
    parser.add_argument('--lds_sigma',
                        type=float,
                        default=2,
                        help='LDS gaussian/laplace kernel sigma')
    # FDS (feature distribution smoothing)
    parser.add_argument('--fds',
                        action='store_true',
                        default=False,
                        help='whether to enable FDS')
    parser.add_argument('--fds_kernel',
                        type=str,
                        default='gaussian',
                        choices=['gaussian', 'triang', 'laplace'],
                        help='FDS kernel type')
    parser.add_argument('--fds_ks',
                        type=int,
                        default=5,
                        help='FDS kernel size: should be odd number')
    parser.add_argument('--fds_sigma',
                        type=float,
                        default=2,
                        help='FDS gaussian/laplace kernel sigma')
    parser.add_argument('--start_update',
                        type=int,
                        default=0,
                        help='which epoch to start FDS updating')
    parser.add_argument(
        '--start_smooth',
        type=int,
        default=1,
        help='which epoch to start using FDS to smooth features')
    parser.add_argument('--bucket_num',
                        type=int,
                        default=50,
                        help='maximum bucket considered for FDS')
    parser.add_argument('--bucket_start',
                        type=int,
                        default=0,
                        help='minimum(starting) bucket for FDS')
    parser.add_argument('--fds_mmt',
                        type=float,
                        default=0.9,
                        help='FDS momentum')

    # re-weighting: SQRT_INV / INV
    parser.add_argument('--reweight',
                        type=str,
                        default='none',
                        choices=['none', 'sqrt_inv', 'inverse'],
                        help='cost-sensitive reweighting scheme')
    # two-stage training: RRT
    parser.add_argument(
        '--retrain_fc',
        action='store_true',
        default=False,
        help='whether to retrain last regression layer (regressor)')
    parser.add_argument(
        '--pretrained',
        type=str,
        default='',
        help='pretrained checkpoint file path to load backbone weights for RRT'
    )
    # evaluate only
    parser.add_argument('--evaluate',
                        action='store_true',
                        default=False,
                        help='evaluate only flag')
    parser.add_argument('--eval_model',
                        type=str,
                        default='',
                        help='the model to evaluate on; if not specified, '
                        'use the default best model in store_dir')

    args = parser.parse_args(arguments)

    os.makedirs(args.store_root, exist_ok=True)

    # Derive a run name that encodes the active hyperparameters, so every
    # distinct configuration gets its own checkpoint directory.
    if not args.lds and args.reweight != 'none':
        args.store_name += f'_{args.reweight}'
    if args.lds:
        args.store_name += f'_lds_{args.lds_kernel[:3]}_{args.lds_ks}'
        if args.lds_kernel in ['gaussian', 'laplace']:
            # sigma only applies to these two kernel shapes
            args.store_name += f'_{args.lds_sigma}'
    if args.fds:
        args.store_name += f'_fds_{args.fds_kernel[:3]}_{args.fds_ks}'
        if args.fds_kernel in ['gaussian', 'laplace']:
            args.store_name += f'_{args.fds_sigma}'
        args.store_name += f'_{args.start_update}_{args.start_smooth}_{args.fds_mmt}'
    if args.retrain_fc:
        args.store_name += f'_retrain_fc'

    if args.loss == 'huber':
        args.store_name += f'_{args.loss}_beta_{args.huber_beta}'
    else:
        args.store_name += f'_{args.loss}'

    args.store_name += f'_seed_{args.random_seed}_valint_{args.val_interval}_patience_{args.patience}' \
                       f'_{args.optimizer}_{args.lr}_{args.batch_size}'
    args.store_name += f'_{args.suffix}' if len(args.suffix) else ''

    args.store_dir = os.path.join(args.store_root, args.store_name)

    # Fresh training run: refuse to silently clobber an existing run dir.
    if not args.evaluate and not args.resume:
        if os.path.exists(args.store_dir):
            if query_yes_no('overwrite previous folder: {} ?'.format(
                    args.store_dir)):
                shutil.rmtree(args.store_dir)
                print(args.store_dir + ' removed.\n')
            else:
                raise RuntimeError('Output folder {} already exists'.format(
                    args.store_dir))
        logging.info(f"===> Creating folder: {args.store_dir}")
        os.makedirs(args.store_dir)

    # Logistics
    # Reset root handlers so reruns in the same process don't double-log;
    # log to file only when the run dir exists (it may not in evaluate mode).
    logging.root.handlers = []
    if os.path.exists(args.store_dir):
        log_file = os.path.join(args.store_dir, args.log_file)
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s | %(message)s",
            handlers=[logging.FileHandler(log_file),
                      logging.StreamHandler()])
    else:
        logging.basicConfig(level=logging.INFO,
                            format="%(asctime)s | %(message)s",
                            handlers=[logging.StreamHandler()])
    logging.info(args)

    # Negative random_seed means "pick one for me".
    seed = random.randint(1,
                          10000) if args.random_seed < 0 else args.random_seed
    random.seed(seed)
    torch.manual_seed(seed)
    if args.cuda >= 0:
        logging.info("Using GPU %d", args.cuda)
        torch.cuda.set_device(args.cuda)
        torch.cuda.manual_seed_all(seed)
    logging.info("Using random seed %d", seed)

    # Load tasks
    logging.info("Loading tasks...")
    start_time = time.time()
    tasks, vocab, word_embs = build_tasks(args)
    logging.info('\tFinished loading tasks in %.3fs', time.time() - start_time)

    # Build model
    logging.info('Building model...')
    start_time = time.time()
    model = build_model(args, vocab, word_embs, tasks)
    logging.info('\tFinished building model in %.3fs',
                 time.time() - start_time)

    # Set up trainer
    iterator = BasicIterator(args.batch_size)
    trainer, train_params, opt_params = build_trainer(args, model, iterator)

    # Train
    if tasks and not args.evaluate:
        if args.retrain_fc and len(args.pretrained):
            # Two-stage RRT: load pretrained backbone weights, then freeze
            # everything except the final regression layer.
            model_path = args.pretrained
            assert os.path.isfile(
                model_path), f"No checkpoint found at '{model_path}'"
            model_state = torch.load(model_path,
                                     map_location=device_mapping(args.cuda))
            # NOTE(review): resume_checkpoint is a project helper; assumed to
            # copy backbone weights into the model when backbone_only=True.
            trainer._model = resume_checkpoint(trainer._model,
                                               model_state,
                                               backbone_only=True)
            logging.info(f'Pre-trained backbone weights loaded: {model_path}')
            logging.info('Retrain last regression layer only!')
            for name, param in trainer._model.named_parameters():
                if "sts-b_pred_layer" not in name:
                    param.requires_grad = False
            logging.info(
                f'Only optimize parameters: {[n for n, p in trainer._model.named_parameters() if p.requires_grad]}'
            )
            to_train = [(n, p) for n, p in trainer._model.named_parameters()
                        if p.requires_grad]
        else:
            to_train = [(n, p) for n, p in model.named_parameters()
                        if p.requires_grad]

        trainer.train(tasks, args.val_interval, to_train, opt_params,
                      args.resume)
    else:
        logging.info("Skipping training...")

    # Evaluate the best checkpoint (or an explicitly given one) on test.
    logging.info('Testing on test set...')
    model_path = os.path.join(
        args.store_dir,
        "model_state_best.th") if not len(args.eval_model) else args.eval_model
    assert os.path.isfile(model_path), f"No checkpoint found at '{model_path}'"
    logging.info(f'Evaluating {model_path}...')
    model_state = torch.load(model_path,
                             map_location=device_mapping(args.cuda))
    model = resume_checkpoint(model, model_state)
    te_preds, te_labels, _ = evaluate(model,
                                      tasks,
                                      iterator,
                                      cuda_device=args.cuda,
                                      split="test")
    # Only archive predictions for this run's own best model, not for an
    # externally supplied checkpoint.
    if not len(args.eval_model):
        np.savez_compressed(os.path.join(args.store_dir,
                                         f"{args.store_name}.npz"),
                            preds=te_preds,
                            labels=te_labels)

    logging.info("Done testing.")
Exemplo n.º 5
0
def main(arguments):
    ''' Train or load a model. Evaluate on some tasks.

    Parses command-line arguments, builds tasks / model / trainer, optionally
    trains (multi-task or sampling trainer), trains per-eval-task classifiers,
    then loads the best (or latest) checkpoint and writes test predictions in
    GLUE submission format.

    Parameters
    ----------
    arguments : list of str
        Raw argv-style arguments forwarded to ``argparse``.
    '''
    parser = argparse.ArgumentParser(description='')

    # Logistics
    parser.add_argument('--cuda', help='-1 if no CUDA, else gpu id', type=int, default=0)
    parser.add_argument('--random_seed', help='random seed to use', type=int, default=19)

    # Paths and logging
    parser.add_argument('--log_file', help='file to log to', type=str, default='log.log')
    parser.add_argument('--exp_dir', help='directory containing shared preprocessing', type=str)
    parser.add_argument('--run_dir', help='directory for saving results, models, etc.', type=str)
    parser.add_argument('--word_embs_file', help='file containing word embs', type=str, default='')
    parser.add_argument('--preproc_file', help='file containing saved preprocessing stuff',
                        type=str, default='preproc.pkl')

    # Time saving flags
    parser.add_argument('--should_train', help='1 if should train model', type=int, default=1)
    parser.add_argument('--load_model', help='1 if load from checkpoint', type=int, default=1)
    parser.add_argument('--load_epoch', help='Force loading from a certain epoch', type=int,
                        default=-1)
    parser.add_argument('--load_tasks', help='1 if load tasks', type=int, default=1)
    parser.add_argument('--load_preproc', help='1 if load vocabulary', type=int, default=1)

    # Tasks and task-specific classifiers
    parser.add_argument('--train_tasks', help='comma separated list of tasks, or "all" or "none"',
                        type=str)
    parser.add_argument('--eval_tasks', help='list of additional tasks to train a classifier,' +
                        'then evaluate on', type=str, default='')
    parser.add_argument('--classifier', help='type of classifier to use', type=str,
                        default='log_reg', choices=['log_reg', 'mlp', 'fancy_mlp'])
    parser.add_argument('--classifier_hid_dim', help='hid dim of classifier', type=int, default=512)
    parser.add_argument('--classifier_dropout', help='classifier dropout', type=float, default=0.0)

    # Preprocessing options
    parser.add_argument('--max_seq_len', help='max sequence length', type=int, default=40)
    parser.add_argument('--max_word_v_size', help='max word vocab size', type=int, default=30000)

    # Embedding options
    parser.add_argument('--dropout_embs', help='dropout rate for embeddings', type=float, default=.2)
    parser.add_argument('--d_word', help='dimension of word embeddings', type=int, default=300)
    parser.add_argument('--glove', help='1 if use glove, else from scratch', type=int, default=1)
    parser.add_argument('--train_words', help='1 if make word embs trainable', type=int, default=0)
    parser.add_argument('--elmo', help='1 if use elmo', type=int, default=0)
    parser.add_argument('--deep_elmo', help='1 if use elmo post LSTM', type=int, default=0)
    parser.add_argument('--elmo_no_glove', help='1 if no glove, assuming elmo', type=int, default=0)
    parser.add_argument('--cove', help='1 if use cove', type=int, default=0)

    # Model options
    parser.add_argument('--pair_enc', help='type of pair encoder to use', type=str, default='simple',
                        choices=['simple', 'attn'])
    parser.add_argument('--d_hid', help='hidden dimension size', type=int, default=4096)
    parser.add_argument('--n_layers_enc', help='number of RNN layers', type=int, default=1)
    parser.add_argument('--n_layers_highway', help='num of highway layers', type=int, default=1)
    parser.add_argument('--dropout', help='dropout rate to use in training', type=float, default=.2)

    # Training options
    parser.add_argument('--no_tqdm', help='1 to turn off tqdm', type=int, default=0)
    parser.add_argument('--trainer_type', help='type of trainer', type=str,
                        choices=['sampling', 'mtl'], default='sampling')
    parser.add_argument('--shared_optimizer', help='1 to use same optimizer for all tasks',
                        type=int, default=1)
    parser.add_argument('--batch_size', help='batch size', type=int, default=64)
    parser.add_argument('--optimizer', help='optimizer to use', type=str, default='sgd')
    parser.add_argument('--n_epochs', help='n epochs to train for', type=int, default=10)
    parser.add_argument('--lr', help='starting learning rate', type=float, default=1.0)
    parser.add_argument('--min_lr', help='minimum learning rate', type=float, default=1e-5)
    parser.add_argument('--max_grad_norm', help='max grad norm', type=float, default=5.)
    parser.add_argument('--weight_decay', help='weight decay value', type=float, default=0.0)
    parser.add_argument('--task_patience', help='patience in decaying per task lr',
                        type=int, default=0)
    parser.add_argument('--scheduler_threshold', help='scheduler threshold',
                        type=float, default=0.0)
    parser.add_argument('--lr_decay_factor', help='lr decay factor when val score doesn\'t improve',
                        type=float, default=.5)

    # Multi-task training options
    parser.add_argument('--val_interval', help='Number of passes between validation checks',
                        type=int, default=10)
    parser.add_argument('--max_vals', help='Maximum number of validation checks', type=int,
                        default=100)
    parser.add_argument('--bpp_method', help='if using nonsampling trainer, ' +
                        'method for calculating number of batches per pass', type=str,
                        choices=['fixed', 'percent_tr', 'proportional_rank'], default='fixed')
    parser.add_argument('--bpp_base', help='If sampling or fixed bpp' +
                        'per pass, this is the bpp. If proportional, this ' +
                        'is the smallest number', type=int, default=10)
    parser.add_argument('--weighting_method', help='Weighting method for sampling', type=str,
                        choices=['uniform', 'proportional'], default='uniform')
    parser.add_argument('--scaling_method', help='method for scaling loss', type=str,
                        choices=['min', 'max', 'unit', 'none'], default='none')
    parser.add_argument('--patience', help='patience in early stopping', type=int, default=5)
    parser.add_argument('--task_ordering', help='Method for ordering tasks', type=str, default='given',
                        choices=['given', 'random', 'random_per_pass', 'small_to_large', 'large_to_small'])

    args = parser.parse_args(arguments)

    # Logistics #
    log.basicConfig(format='%(asctime)s: %(message)s', level=log.INFO, datefmt='%m/%d %I:%M:%S %p')
    log_file = os.path.join(args.run_dir, args.log_file)
    file_handler = log.FileHandler(log_file)
    log.getLogger().addHandler(file_handler)
    log.info(args)
    # A negative --random_seed means "pick one at random"; otherwise seed is fixed.
    seed = random.randint(1, 10000) if args.random_seed < 0 else args.random_seed
    random.seed(seed)
    torch.manual_seed(seed)
    if args.cuda >= 0:
        log.info("Using GPU %d", args.cuda)
        torch.cuda.set_device(args.cuda)
        torch.cuda.manual_seed_all(seed)
    log.info("Using random seed %d", seed)

    # Load tasks #
    log.info("Loading tasks...")
    start_time = time.time()
    train_tasks, eval_tasks, vocab, word_embs = build_tasks(args)
    tasks = train_tasks + eval_tasks
    log.info('\tFinished loading tasks in %.3fs', time.time() - start_time)

    # Build model #
    log.info('Building model...')
    start_time = time.time()
    model = build_model(args, vocab, word_embs, tasks)
    log.info('\tFinished building model in %.3fs', time.time() - start_time)

    # Set up trainer #
    # TODO(Alex): move iterator creation
    iterator = BasicIterator(args.batch_size)
    #iterator = BucketIterator(sorting_keys=[("sentence1", "num_tokens")], batch_size=args.batch_size)
    trainer, train_params, opt_params, schd_params = build_trainer(args, args.trainer_type, model, iterator)

    # Train #
    if train_tasks and args.should_train:
        #to_train = [p for p in model.parameters() if p.requires_grad]
        to_train = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        if args.trainer_type == 'mtl':
            best_epochs = trainer.train(train_tasks, args.task_ordering, args.val_interval,
                                        args.max_vals, args.bpp_method, args.bpp_base, to_train,
                                        opt_params, schd_params, args.load_model)
        elif args.trainer_type == 'sampling':
            if args.weighting_method == 'uniform':
                log.info("Sampling tasks uniformly")
            elif args.weighting_method == 'proportional':
                log.info("Sampling tasks proportional to number of training batches")

            if args.scaling_method == 'max':
                # divide by # batches, multiply by max # batches
                log.info("Scaling losses to largest task")
            elif args.scaling_method == 'min':
                # divide by # batches, multiply by fewest # batches
                log.info("Scaling losses to the smallest task")
            elif args.scaling_method == 'unit':
                log.info("Dividing losses by number of training batches")
            best_epochs = trainer.train(train_tasks, args.val_interval, args.bpp_base,
                                        args.weighting_method, args.scaling_method, to_train,
                                        opt_params, schd_params, args.shared_optimizer,
                                        args.load_model)
    else:
        log.info("Skipping training.")
        best_epochs = {}

    # train just the classifiers for eval tasks
    for task in eval_tasks:
        pred_layer = getattr(model, "%s_pred_layer" % task.name)
        to_train = pred_layer.parameters()
        trainer = MultiTaskTrainer.from_params(model, args.run_dir + '/%s/' % task.name,
                                               iterator, copy.deepcopy(train_params))
        trainer.train([task], args.task_ordering, 1, args.max_vals, 'percent_tr', 1, to_train,
                      opt_params, schd_params, 1)
        layer_path = os.path.join(args.run_dir, task.name, "%s_best.th" % task.name)
        layer_state = torch.load(layer_path, map_location=device_mapping(args.cuda))
        # NOTE(review): this loads a layer-only state dict into the full model;
        # presumably relies on the checkpoint containing full-model keys — verify.
        model.load_state_dict(layer_state)

    # Evaluate: load the different task best models and evaluate them
    # TODO(Alex): put this in evaluate file
    all_results = {}

    if not best_epochs and args.load_epoch >= 0:
        # Explicitly requested epoch.
        epoch_to_load = args.load_epoch
    elif not best_epochs:
        # BUG FIX: this branch used to be `elif not best_epochs and not
        # args.load_epoch:`, which is False for the default --load_epoch of -1
        # (and the 0 case is already handled above), so the auto-detect path
        # was unreachable and epoch_to_load silently fell through to -1.
        # With no trained epochs and no forced epoch, load the latest checkpoint.
        serialization_files = os.listdir(args.run_dir)
        model_checkpoints = [x for x in serialization_files if "model_state_epoch" in x]
        # `strip(".th")` strips the character set {'.', 't', 'h'}, not the
        # suffix, and breaks int() on timestamped checkpoints such as
        # "model_state_epoch_5.2018-...th"; split on '.' to get the int epoch.
        epoch_to_load = max(int(x.split("model_state_epoch_")[-1].split(".")[0])
                            for x in model_checkpoints)
    else:
        epoch_to_load = -1  # unused: best_epochs supplies the epoch below

    #for task in [task.name for task in train_tasks] + ['micro', 'macro']:
    for task in ['macro']:
        log.info("Testing on %s..." % task)

        # Load best model
        load_idx = best_epochs[task] if best_epochs else epoch_to_load
        model_path = os.path.join(args.run_dir, "model_state_epoch_{}.th".format(load_idx))
        model_state = torch.load(model_path, map_location=device_mapping(args.cuda))
        model.load_state_dict(model_state)

        # Test evaluation and prediction
        # could just filter out tasks to get what i want...
        #tasks = [task for task in tasks if 'mnli' in task.name]
        te_results, te_preds = evaluate(model, tasks, iterator, cuda_device=args.cuda, split="test")
        val_results, _ = evaluate(model, tasks, iterator, cuda_device=args.cuda, split="val")

        if task == 'macro':
            all_results[task] = (val_results, te_results, model_path)
            for eval_task, task_preds in te_preds.items(): # write predictions for each task
                #if 'mnli' not in eval_task:
                #    continue
                idxs_and_preds = [(idx, pred) for pred, idx in zip(task_preds[0], task_preds[1])]
                idxs_and_preds.sort(key=lambda x: x[0])
                if 'mnli' in eval_task:
                    # MNLI predictions are split into matched / mismatched /
                    # diagnostic files by fixed example counts (9796 / 9847 / rest).
                    pred_map = {0: 'neutral', 1: 'entailment', 2: 'contradiction'}
                    with open(os.path.join(args.run_dir, "%s-m.tsv" % (eval_task)), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        split_idx = 0
                        for idx, pred in idxs_and_preds[:9796]:
                            pred = pred_map[pred]
                            pred_fh.write("%d\t%s\n" % (split_idx, pred))
                            split_idx += 1
                    with open(os.path.join(args.run_dir, "%s-mm.tsv" % (eval_task)), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        split_idx = 0
                        for idx, pred in idxs_and_preds[9796:9796+9847]:
                            pred = pred_map[pred]
                            pred_fh.write("%d\t%s\n" % (split_idx, pred))
                            split_idx += 1
                    with open(os.path.join(args.run_dir, "diagnostic.tsv"), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        split_idx = 0
                        for idx, pred in idxs_and_preds[9796+9847:]:
                            pred = pred_map[pred]
                            pred_fh.write("%d\t%s\n" % (split_idx, pred))
                            split_idx += 1
                else:
                    with open(os.path.join(args.run_dir, "%s.tsv" % (eval_task)), 'w') as pred_fh:
                        pred_fh.write("index\tprediction\n")
                        for idx, pred in idxs_and_preds:
                            if 'sts-b' in eval_task:
                                pred_fh.write("%d\t%.3f\n" % (idx, pred))
                            elif 'rte' in eval_task:
                                pred = 'entailment' if pred else 'not_entailment'
                                pred_fh.write('%d\t%s\n' % (idx, pred))
                            elif 'squad' in eval_task:
                                pred = 'entailment' if pred else 'not_entailment'
                                pred_fh.write('%d\t%s\n' % (idx, pred))
                            else:
                                pred_fh.write("%d\t%d\n" % (idx, pred))

            with open(os.path.join(args.exp_dir, "results.tsv"), 'a') as results_fh: # aggregate results easily
                run_name = args.run_dir.split('/')[-1]
                all_metrics_str = ', '.join(['%s: %.3f' % (metric, score) for \
                                            metric, score in val_results.items()])
                results_fh.write("%s\t%s\n" % (run_name, all_metrics_str))
    log.info("Done testing")

    # Dump everything to a pickle for posterity
    pkl.dump(all_results, open(os.path.join(args.run_dir, "results.pkl"), 'wb'))
Exemplo n.º 6
0
    def _restore_checkpoint(self) -> Tuple[int, List[float]]:
        """
        Restores a model from a serialization_dir to the last saved checkpoint.
        This includes an epoch count and optimizer state, which is serialized separately
        from  model parameters. This function should only be used to continue training -
        if you wish to load a model for inference/load parts of a model into a new
        computation graph, you should use the native Pytorch functions:
        `` model.load_state_dict(torch.load("/path/to/model/weights.th"))``

        If ``self._serialization_dir`` does not exist or does not contain any checkpointed weights,
        this function will do nothing and return 0.

        Returns
        -------
        epoch: int
            The epoch at which to resume training, which should be one after the epoch
            in the saved training state.
        val_metric_per_epoch: List[float]
            Per-epoch validation metrics restored from the training state
            (empty for old checkpoints that did not save them).
        """
        have_checkpoint = (self._serialization_dir is not None and any(
            "model_state_epoch_" in x
            for x in os.listdir(self._serialization_dir)))

        if not have_checkpoint:
            # No checkpoint to restore, start at 0
            return 0, []

        serialization_files = os.listdir(self._serialization_dir)
        model_checkpoints = [
            x for x in serialization_files if "model_state_epoch" in x
        ]
        # Get the last checkpoint file.  Epochs are specified as either an
        # int (for end of epoch files) or with epoch and timestamp for
        # within epoch checkpoints, e.g. 5.2018-02-02-15-33-42
        found_epochs = [
            re.search(r"model_state_epoch_([0-9\.\-]+)\.th", x).group(1)
            for x in model_checkpoints
        ]
        int_epochs = []
        for epoch in found_epochs:
            pieces = epoch.split('.')
            if len(pieces) == 1:
                # Just a single epoch without timestamp.
                # BUG FIX: use the string sentinel '0' rather than the int 0.
                # Mixing [int, 0] and [int, "<timestamp>"] entries made
                # sorted() raise TypeError on Python 3 whenever both an
                # end-of-epoch and a within-epoch checkpoint existed for the
                # same epoch. '0' sorts lexically below any timestamp, so the
                # timestamped entry still wins, preserving prior ordering.
                int_epochs.append([int(pieces[0]), '0'])
            else:
                # has a timestamp
                int_epochs.append([int(pieces[0]), pieces[1]])
        last_epoch = sorted(int_epochs, reverse=True)[0]
        if last_epoch[1] == '0':
            epoch_to_load = str(last_epoch[0])
        else:
            epoch_to_load = '{0}.{1}'.format(last_epoch[0], last_epoch[1])

        model_path = os.path.join(
            self._serialization_dir,
            "model_state_epoch_{}.th".format(epoch_to_load))
        training_state_path = os.path.join(
            self._serialization_dir,
            "training_state_epoch_{}.th".format(epoch_to_load))

        # Load the parameters onto CPU, then transfer to GPU.
        # This avoids potential OOM on GPU for large models that
        # load parameters onto GPU then make a new GPU copy into the parameter
        # buffer. The GPU transfer happens implicitly in load_state_dict.
        model_state = torch.load(model_path,
                                 map_location=util.device_mapping(-1))
        training_state = torch.load(training_state_path,
                                    map_location=util.device_mapping(-1))
        self._model.load_state_dict(model_state)
        self._optimizer.load_state_dict(training_state["optimizer"])
        # move_optimizer_to_cuda(self._optimizer)

        # We didn't used to save `validation_metric_per_epoch`, so we can't assume
        # that it's part of the trainer state. If it's not there, an empty list is all
        # we can do.
        if "val_metric_per_epoch" not in training_state:
            logger.warning(
                "trainer state `val_metric_per_epoch` not found, using empty list"
            )
            val_metric_per_epoch = []
        else:
            val_metric_per_epoch = training_state["val_metric_per_epoch"]

        # Saved epoch may itself be timestamped ("5.2018-..."); resume from the
        # integer epoch after it either way.
        if isinstance(training_state["epoch"], int):
            epoch_to_return = training_state["epoch"] + 1
        else:
            epoch_to_return = int(training_state["epoch"].split('.')[0]) + 1

        # For older checkpoints with batch_num_total missing, default to old behavior where
        # it is unchanged.
        batch_num_total = training_state.get('batch_num_total')
        if batch_num_total is not None:
            self._batch_num_total = batch_num_total

        return epoch_to_return, val_metric_per_epoch