Ejemplo n.º 1
0
    def __init__(self, opt, shared=None):
        """Build the fairseq model, criterion and trainer for this agent.

        When ``shared`` is truthy only the parent initialisation and
        ``reset()`` run. Otherwise the full setup is performed, optionally
        restoring options and weights from ``opt['model_file']`` when that
        file exists.
        """
        # the parent class sets up its defaults before anything else happens
        super().__init__(opt, shared)

        if shared:
            # shared copy of this class: nothing heavy to construct here
            self.reset()
            return

        # full initialisation for a non-shared instance
        saved_state = None
        if opt.get('model_file') and os.path.isfile(opt['model_file']):
            # a saved model exists: read it and let its options take over
            print('Loading existing model params from ' + opt['model_file'])
            stored_opt, saved_state = self.load(opt['model_file'])
            opt = self._override_opt(stored_opt)

        self.args = OptWrapper(opt)
        self.fairseq_dict = _make_fairseq_dict(DictionaryAgent(opt))
        self.id = 'Fairseq'
        # non-positive truncate values are stored as None
        if opt['truncate'] > 0:
            self.truncate = opt['truncate']
        else:
            self.truncate = None

        self.EOS = self.fairseq_dict[self.fairseq_dict.eos()]
        eos_tensor = torch.LongTensor(1, 1)
        self.EOS_TENSOR = eos_tensor.fill_(self.fairseq_dict.eos())
        self.NULL_IDX = self.fairseq_dict.pad()

        # NOTE(review): the layer/attention specs are option strings that are
        # passed through eval(); they must come from a trusted source.
        enc = fconv.Encoder(
            self.fairseq_dict,
            embed_dim=self.args.encoder_embed_dim,
            convolutions=eval(self.args.encoder_layers),
            dropout=self.args.dropout,
            max_positions=self.args.max_positions)
        dec = fconv.Decoder(
            self.fairseq_dict,
            embed_dim=self.args.decoder_embed_dim,
            convolutions=eval(self.args.decoder_layers),
            out_embed_dim=self.args.decoder_out_embed_dim,
            attention=eval(self.args.decoder_attention),
            dropout=self.args.dropout,
            max_positions=self.args.max_positions)
        self.model = fconv.FConvModel(enc, dec)

        # mirrors fairseq's build_criterion()
        if not self.args.label_smoothing > 0:
            self.criterion = criterions.CrossEntropyCriterion(self.NULL_IDX)
        else:
            self.criterion = criterions.LabelSmoothedCrossEntropyCriterion(
                self.args.label_smoothing, self.NULL_IDX)

        self.trainer = MultiprocessingTrainer(
            self.args, self.model, self.criterion)
        if saved_state is not None:
            # weights were found on disk: push them into the trainer's model
            self.set_states(saved_state)

        self.reset()
Ejemplo n.º 2
0
    def __init__(self, opt, shared=None):
        """Build the fairseq model, criterion and trainer for this agent.

        When ``shared`` is truthy only the parent initialisation and
        ``reset()`` run. Otherwise the full setup is performed, optionally
        restoring options and weights from ``opt['model_file']`` when that
        file exists.
        """
        # the parent class sets up its defaults before anything else happens
        super().__init__(opt, shared)

        if shared:
            # shared copy of this class: nothing heavy to construct here
            self.reset()
            return

        if opt.get('model_file') and os.path.isfile(opt['model_file']):
            # a saved model exists: read it and let its options take over.
            # self.saved_state is only ever assigned on this path.
            print('Loading existing model params from ' + opt['model_file'])
            stored_opt, self.saved_state = self.load(opt['model_file'])
            opt = self._override_opt(stored_opt)

        self.args = OptWrapper(opt)
        self.fairseq_dict = _make_fairseq_dict(DictionaryAgent(opt))
        self.id = 'Fairseq'

        self.EOS = self.fairseq_dict[self.fairseq_dict.eos()]
        self.NULL_IDX = self.fairseq_dict.pad()

        # NOTE(review): the layer/attention specs are option strings that are
        # passed through eval(); they must come from a trusted source.
        enc = fconv.Encoder(
            len(self.fairseq_dict),
            embed_dim=self.args.encoder_embed_dim,
            convolutions=eval(self.args.encoder_layers),
            dropout=self.args.dropout,
            padding_idx=self.NULL_IDX,
            max_positions=self.args.max_positions)
        dec = fconv.Decoder(
            len(self.fairseq_dict),
            embed_dim=self.args.decoder_embed_dim,
            convolutions=eval(self.args.decoder_layers),
            out_embed_dim=self.args.decoder_out_embed_dim,
            attention=eval(self.args.decoder_attention),
            dropout=self.args.dropout,
            padding_idx=self.NULL_IDX,
            max_positions=self.args.max_positions)
        self.model = fconv.FConvModel(enc, dec, self.NULL_IDX)

        # mirrors fairseq's build_criterion()
        if not self.args.label_smoothing > 0:
            self.criterion = criterions.CrossEntropyCriterion(self.NULL_IDX)
        else:
            self.criterion = criterions.LabelSmoothedCrossEntropyCriterion(
                self.args.label_smoothing, self.NULL_IDX)

        self.trainer = MultiprocessingTrainer(self.args, self.model)
        if hasattr(self, 'saved_state'):
            # weights were found on disk: push them into the trainer's model
            self.set_states(self.saved_state)

        self.reset()
Ejemplo n.º 3
0
def main():
    """Train a model on the GPU(s), then score the test subsets.

    Steps, in order: parse command-line arguments, load the dataset, build
    the model and criterion, start a ``MultiprocessingTrainer``, resume from
    the restore checkpoint, train epoch by epoch until the learning rate
    drops to ``args.min_lr`` or ``args.max_epoch`` is exceeded, and finally
    run ``score_test`` on every test subset for beam sizes 1, 5, 10 and 20.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    parser = options.get_parser('Trainer')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('--max-tokens',
                              default=0,
                              type=int,
                              metavar='N',
                              help='maximum number of tokens in a batch')
    dataset_args.add_argument('--batch-size',
                              default=32,
                              type=int,
                              metavar='N',
                              help='batch size')
    dataset_args.add_argument('--test-batch-size',
                              default=32,
                              type=int,
                              metavar='N',
                              help='batch size for test set')
    dataset_args.add_argument('--valid-batch-size',
                              default=32,
                              type=int,
                              metavar='N',
                              help='batch size for validation set')
    dataset_args.add_argument(
        '--train-subset',
        default='train',
        metavar='SPLIT',
        choices=['train', 'valid', 'test'],
        help='data subset to use for training (train, valid, test)')
    dataset_args.add_argument(
        '--valid-subset',
        default='valid',
        metavar='SPLIT',
        help='comma separated list of data subsets '
        'to use for validation (train, valid, valid1, test, test1)')
    dataset_args.add_argument('--test-subset',
                              default='test',
                              metavar='SPLIT',
                              help='comma separated list of data subsets '
                              'to use for testing (train, valid, test)')
    dataset_args.add_argument(
        '--valid-script',
        nargs='+',
        metavar='PATH',
        help='path to external validation script (optional).')

    options.add_optimization_args(parser)
    options.add_checkpoint_args(parser)
    options.add_model_args(parser)

    args = utils.parse_args_and_arch(parser)
    print(args)

    if args.no_progress_bar:
        progress_bar.enabled = False
        progress_bar.print_interval = args.log_interval

    # exist_ok avoids the check-then-create race when several processes
    # start with the same save directory
    os.makedirs(args.save_dir, exist_ok=True)
    torch.manual_seed(args.seed)

    # A --max-tokens of 0 means "no limit", which is represented as None
    if args.max_tokens == 0:
        args.max_tokens = None

    # Load dataset
    dataset = data.load_with_check(args.data, args.source_lang,
                                   args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    for split in dataset.splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    num_gpus = torch.cuda.device_count()

    print('| using {} GPUs (with max tokens per GPU = {})'.format(
        num_gpus, args.max_tokens))

    # Build model
    print('| model {}'.format(args.arch))
    model = utils.build_model(args, dataset)
    criterion = utils.build_criterion(args, dataset)

    # Start multiprocessing
    trainer = MultiprocessingTrainer(args, model)

    # Load the latest checkpoint if one is available
    epoch, batch_offset = trainer.load_checkpoint(
        os.path.join(args.save_dir, args.restore_file))

    # Train until the learning rate gets too small
    val_loss = None
    # a falsy max_epoch (0 / None) means "no epoch limit"
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch <= max_epoch:
        # train for one epoch
        train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus)

        # evaluate on validate set
        for k, subset in enumerate(args.valid_subset.split(',')):
            val_loss = validate(args, epoch, trainer, criterion, dataset,
                                subset, num_gpus)
            if k == 0:
                if not args.no_save:
                    # save checkpoint
                    trainer.save_checkpoint(
                        args,
                        epoch,
                        0,
                        val_loss,
                        validation_script=args.valid_script)

                # only use first validation loss to update the learning schedule
                lr = trainer.lr_step(val_loss, epoch)

        epoch += 1
        # only the first (resumed) epoch starts part-way through
        batch_offset = 0
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))

    # Generate on test set and compute BLEU score
    for beam in [1, 5, 10, 20]:
        for subset in args.test_subset.split(','):
            scorer = score_test(args,
                                trainer.get_model(),
                                dataset,
                                subset,
                                beam,
                                cuda_device=(0 if num_gpus > 0 else None))
            print('| Test on {} with beam={}: {}'.format(
                subset, beam, scorer.result_string()))

    # Stop multiprocessing
    trainer.stop()
Ejemplo n.º 4
0
class FairseqAgent(Agent):
    """Agent which takes an input sequence and produces an output sequence.

    The agent wraps a fairseq ``fconv`` encoder/decoder model driven by a
    ``MultiprocessingTrainer``. Batches whose first valid observation carries
    ``labels`` trigger a training step; other batches are answered with the
    output of a ``SequenceGenerator``.

    For more information, see Convolutional Sequence to Sequence Learning
     `(Gehring et al. 2017) <https://arxiv.org/abs/1705.03122>`_.
    """
    @staticmethod
    def add_cmdline_args(argparser):
        """Add command-line arguments specifically for this agent."""
        DictionaryAgent.add_cmdline_args(argparser)
        agent = argparser.add_argument_group('Fairseq Arguments')

        agent.add_argument('--max-positions',
                           default=1024,
                           type=int,
                           metavar='N',
                           help='max number of tokens in the sequence')
        agent.add_argument('--seed',
                           default=1,
                           type=int,
                           metavar='N',
                           help='pseudo random number generator seed')
        agent.add_argument('--lr',
                           '--learning-rate',
                           default=0.25,
                           type=float,
                           metavar='LR',
                           help='initial learning rate')
        agent.add_argument('--momentum',
                           default=0.99,
                           type=float,
                           metavar='M',
                           help='momentum factor')
        agent.add_argument('--weight-decay',
                           '--wd',
                           default=0.0,
                           type=float,
                           metavar='WD',
                           help='weight decay')
        agent.add_argument('--force-anneal',
                           '--fa',
                           default=0,
                           type=int,
                           metavar='N',
                           help='force annealing at specified epoch')
        agent.add_argument('--beam',
                           default=5,
                           type=int,
                           metavar='N',
                           help='beam size')
        agent.add_argument(
            '--no-early-stop',
            action='store_true',
            help=('continue searching even after finalizing k=beam '
                  'hypotheses; this is more correct, but increases '
                  'generation time by 50%%'))
        agent.add_argument('--unnormalized',
                           action='store_true',
                           help='compare unnormalized hypothesis scores')

        agent.add_argument(
            '--lenpen',
            default=1,
            type=float,
            help=('length penalty: <1.0 favors shorter, '
                  '>1.0 favors longer sentences'))

        agent.add_argument('--clip-norm',
                           default=25,
                           type=float,
                           metavar='NORM',
                           help='clip threshold of gradients')

        agent.add_argument('--arch',
                           '-a',
                           default='fconv',
                           metavar='ARCH',
                           choices=models.arch_model_map.keys(),
                           help='model architecture ({})'.format(', '.join(
                               models.arch_model_map.keys())))
        agent.add_argument('--encoder-embed-dim',
                           type=int,
                           metavar='N',
                           help='encoder embedding dimension')
        agent.add_argument('--encoder-layers',
                           type=str,
                           metavar='EXPR',
                           help='encoder layers [(dim, kernel_size), ...]')
        agent.add_argument('--decoder-embed-dim',
                           type=int,
                           metavar='N',
                           help='decoder embedding dimension')
        agent.add_argument('--decoder-layers',
                           type=str,
                           metavar='EXPR',
                           help='decoder layers [(dim, kernel_size), ...]')
        agent.add_argument('--decoder-out-embed-dim',
                           type=int,
                           metavar='N',
                           help='decoder output embedding dimension')
        agent.add_argument('--decoder-attention',
                           type=str,
                           metavar='EXPR',
                           help='decoder attention [True, ...]')

        # These arguments have default values independent of the model:
        agent.add_argument('--dropout',
                           default=0.1,
                           type=float,
                           metavar='D',
                           help='dropout probability')
        agent.add_argument(
            '--label-smoothing',
            default=0,
            type=float,
            metavar='D',
            help='epsilon for label smoothing, 0 means no label smoothing')

    def __init__(self, opt, shared=None):
        """Build the model, criterion and trainer unless ``shared`` is set.

        If ``opt['model_file']`` names an existing file, its stored options
        override the model-specific entries of ``opt`` and its weights are
        loaded into the freshly built model.
        """
        # initialize defaults first
        super().__init__(opt, shared)
        if not shared:
            # this is not a shared instance of this class, so do full
            # initialization. if shared is set, only set up shared members.

            if opt.get('model_file') and os.path.isfile(opt['model_file']):
                # load model parameters if available
                print('Loading existing model params from ' +
                      opt['model_file'])
                new_opt, self.saved_state = self.load(opt['model_file'])
                # override options with stored ones
                opt = self._override_opt(new_opt)

            self.args = OptWrapper(opt)
            self.fairseq_dict = _make_fairseq_dict(DictionaryAgent(opt))
            self.id = 'Fairseq'

            self.EOS = self.fairseq_dict[self.fairseq_dict.eos()]
            self.NULL_IDX = self.fairseq_dict.pad()

            # NOTE(security): the layer and attention specs are option
            # strings (e.g. '[(dim, kernel_size), ...]') that are passed
            # through eval(); only use options and model files from a
            # trusted source.
            encoder = fconv.Encoder(len(self.fairseq_dict),
                                    embed_dim=self.args.encoder_embed_dim,
                                    convolutions=eval(
                                        self.args.encoder_layers),
                                    dropout=self.args.dropout,
                                    padding_idx=self.NULL_IDX,
                                    max_positions=self.args.max_positions)
            decoder = fconv.Decoder(
                len(self.fairseq_dict),
                embed_dim=self.args.decoder_embed_dim,
                convolutions=eval(self.args.decoder_layers),
                out_embed_dim=self.args.decoder_out_embed_dim,
                attention=eval(self.args.decoder_attention),
                dropout=self.args.dropout,
                padding_idx=self.NULL_IDX,
                max_positions=self.args.max_positions)
            self.model = fconv.FConvModel(encoder, decoder, self.NULL_IDX)

            # from fairseq's build_criterion()
            if self.args.label_smoothing > 0:
                self.criterion = criterions.LabelSmoothedCrossEntropyCriterion(
                    self.args.label_smoothing, self.NULL_IDX)
            else:
                self.criterion = criterions.CrossEntropyCriterion(
                    self.NULL_IDX)

            self.trainer = MultiprocessingTrainer(self.args, self.model)
            # saved_state only exists when a model file was loaded above
            if hasattr(self, 'saved_state'):
                self.set_states(self.saved_state)

        self.reset()

    def _override_opt(self, new_opt):
        """Set overridable opts from loaded opt file.

        Print out each added key and each overriden key.
        Only override args specific to the model.

        Returns ``self.opt`` after it has been updated in place.
        """
        model_args = {
            'arch',
            'encoder-embed-dim',
            'encoder-layers',
            'decoder-embed-dim',
            'decoder-layers',
            'decoder-out-embed-dim',
            'decoder-attention',
        }

        for k, v in new_opt.items():
            if k not in model_args:
                # skip non-model args
                continue
            if k not in self.opt:
                print('Adding new option [ {k}: {v} ]'.format(k=k, v=v))
            elif self.opt[k] != v:
                print('Overriding option [ {k}: {old} => {v}]'.format(
                    k=k, old=self.opt[k], v=v))
            self.opt[k] = v
        return self.opt

    def reset(self):
        """Reset observation and episode_done."""
        self.observation = None
        self.episode_done = True

    def observe(self, observation):
        """Store a copy of ``observation``, prepending unfinished dialogue.

        While the previous observation did not end its episode, its text is
        joined to the new text with a newline. Returns the stored copy.
        """
        # shallow copy observation (deep copy can be expensive)
        observation = observation.copy()
        if not self.episode_done:
            # if the last example wasn't the end of an episode, then we need to
            # recall what was said in that example
            prev_dialogue = self.observation['text']
            observation['text'] = prev_dialogue + '\n' + observation['text']
        self.observation = observation
        self.episode_done = observation['episode_done']
        return observation

    def act(self):
        """Act on the stored observation; return the single reply dict."""
        # call batch_act with this batch of one
        return self.batch_act([self.observation])[0]

    def batch_act(self, observations):
        """Train on, or generate replies for, a batch of observations.

        Returns one reply dict per observation. A ``'text'`` entry is only
        filled in for valid observations when generating (no labels).
        """
        batchsize = len(observations)
        # initialize a table of replies with this agent's id
        batch_reply = [{'id': self.getID()} for _ in range(batchsize)]

        # convert the observations into batches of inputs and targets
        # valid_inds tells us the indices of all valid examples
        # e.g. for input [{}, {'text': 'hello'}, {}, {}], valid_inds is [1]
        # since the other three elements had no 'text' field
        xs, ys, valid_inds = self.batchify(observations)

        # len() rather than truthiness: xs is normally a multi-element tensor
        if len(xs) == 0:
            # no valid examples, just return the empty responses we set up
            return batch_reply

        # produce predictions if testing; otherwise, train
        if ys is None:
            predictions = self._generate(self.args, xs)
            for i, prediction in enumerate(predictions):
                # map the predictions back to non-empty examples in the batch
                batch_reply[valid_inds[i]]['text'] = prediction
        else:
            self._train(xs, ys)

        return batch_reply

    def parse(self, string):
        """Map each space-separated word of ``string`` to its dict index."""
        return [self.fairseq_dict.index(word) for word in string.split(' ')]

    def batchify(self, observations):
        """Convert a list of observations into input & target tensors.

        Returns ``(xs, ys, valid_inds)``. ``valid_inds`` lists the positions
        of observations that contain ``'text'``. ``xs`` holds their token
        ids, padded on the left. ``ys`` holds the label token ids padded on
        the right, or is ``None`` when the first valid observation has no
        ``'labels'``. When no observation is valid, ``xs`` is an empty list
        and ``ys`` is ``None``.
        """
        # the indices of the valid (non-empty) examples
        valid_inds = [i for i, ex in enumerate(observations) if 'text' in ex]
        if not valid_inds:
            # nothing to batch; max() and exs[0] below would raise otherwise
            return [], None, valid_inds
        # valid examples
        exs = [observations[i] for i in valid_inds]

        # set up the input tensors
        batchsize = len(exs)
        pad = self.fairseq_dict.pad()
        # tokenize the text
        parsed = [self.parse(ex['text']) for ex in exs]
        max_x_len = max(len(x) for x in parsed)
        xs = torch.LongTensor(batchsize, max_x_len).fill_(pad)
        # pack the data to the right side of the tensor for this model
        for i, x in enumerate(parsed):
            offset = max_x_len - len(x)
            for j, idx in enumerate(x):
                xs[i][j + offset] = idx
        # `async` is a reserved word since Python 3.7; PyTorch's replacement
        # keyword for the same option is `non_blocking`
        xs = xs.cuda(non_blocking=True)

        # set up the target tensors
        ys = None
        if 'labels' in exs[0]:
            # randomly select one of the labels to update on, if multiple
            # append EOS to each label
            labels = [
                random.choice(ex['labels']) + ' ' + self.EOS for ex in exs
            ]
            parsed = [self.parse(y) for y in labels]
            max_y_len = max(len(y) for y in parsed)
            ys = torch.LongTensor(batchsize, max_y_len).fill_(pad)
            for i, y in enumerate(parsed):
                for j, idx in enumerate(y):
                    ys[i][j] = idx
            ys = ys.cuda(non_blocking=True)
        return xs, ys, valid_inds

    def _positions_for_tokens(self, tokens):
        """Return a position tensor with the same size as ``tokens``.

        Padding entries keep the padding index; the non-padding entries of
        each row are numbered consecutively starting at ``pad() + 1``.
        """
        pad = self.fairseq_dict.pad()
        start = pad + 1
        size = tokens.size()
        positions = torch.LongTensor(size).fill_(pad)
        for i in range(size[0]):
            nonpad = 0
            for j in range(size[1]):
                if tokens[i][j] != pad:
                    positions[i][j] = start + nonpad
                    nonpad += 1
        return positions.cuda(non_blocking=True)

    def _right_shifted_ys(self, ys):
        """Return ``ys`` shifted one column right with EOS in column 0."""
        result = torch.LongTensor(ys.size())
        result[:, 0] = self.fairseq_dict.index(self.EOS)
        result[:, 1:] = ys[:, :-1]
        return result

    def _generate(self, opt, src_tokens):
        """Generate one output string per row of ``src_tokens``.

        ``opt`` supplies ``beam``, ``no_early_stop``, ``unnormalized`` and
        ``lenpen`` for the ``SequenceGenerator``.
        """
        translator = SequenceGenerator([self.trainer.get_model()],
                                       self.fairseq_dict,
                                       beam_size=opt.beam,
                                       stop_early=(not opt.no_early_stop),
                                       normalize_scores=(not opt.unnormalized),
                                       len_penalty=opt.lenpen)
        translator.cuda()
        translations = translator.generate(
            Variable(src_tokens),
            Variable(self._positions_for_tokens(src_tokens)))
        # keep the first hypothesis of each input and drop its last token
        return [
            ' '.join(self.fairseq_dict[idx]
                     for idx in hypotheses[0]['tokens'][:-1])
            for hypotheses in translations
        ]

    def _train(self, xs, ys=None):
        """Run one trainer step on ``xs``/``ys``; do nothing without ``ys``."""
        if ys is not None:
            sample = {
                'src_tokens': xs,
                'input_tokens': self._right_shifted_ys(ys),
                'target': ys,
                'id': None
            }
            # sums the row lengths of the target tensor
            sample['ntokens'] = sum(len(t) for t in sample['target'])
            sample['src_positions'] = self._positions_for_tokens(
                sample['src_tokens'])
            sample['input_positions'] = self._positions_for_tokens(
                sample['input_tokens'])
            self.trainer.train_step([sample], self.criterion)

    def save(self, path=None):
        """Write the model state dict and ``self.opt`` to ``path``.

        Falls back to ``opt['model_file']`` when ``path`` is None. Nothing is
        written if no path is available or the agent has no trainer.
        """
        path = self.opt.get('model_file', None) if path is None else path
        if path and hasattr(self, 'trainer'):
            model = {}
            model['state_dict'] = self.trainer.get_model().state_dict()
            model['opt'] = self.opt
            with open(path, 'wb') as write:
                torch.save(model, write)

    def shutdown(self):
        """Save the state of the model when shutdown."""
        path = self.opt.get('model_file', None)
        if path is not None:
            self.save(path + '.shutdown_state')
        super().shutdown()

    def load(self, path):
        """Return opt and model states."""
        # NOTE(security): torch.load unpickles the file, so only load model
        # files from a trusted source.
        with open(path, 'rb') as read:
            model = torch.load(read)
        return model['opt'], model['state_dict']

    def set_states(self, state_dict):
        """Set the state dict of the model from saved states."""
        self.trainer.get_model().load_state_dict(state_dict)
Ejemplo n.º 5
0
def main():
    """Train a model on the GPU(s) until the learning rate gets too small.

    Steps, in order: parse command-line arguments, load the ``train`` and
    ``valid`` splits, build the model and criterion, start a
    ``MultiprocessingTrainer``, resume from the restore checkpoint if one was
    loaded, then train and validate epoch by epoch until the learning rate
    drops to ``args.min_lr`` or ``args.max_epoch`` is exceeded.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    parser = options.get_parser('Trainer')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('--max-tokens',
                              default=6000,
                              type=int,
                              metavar='N',
                              help='maximum number of tokens in a batch')
    dataset_args.add_argument(
        '--train-subset',
        default='train',
        metavar='SPLIT',
        choices=['train', 'valid', 'test'],
        help='data subset to use for training (train, valid, test)')
    dataset_args.add_argument(
        '--valid-subset',
        default='valid',
        metavar='SPLIT',
        help='comma separated list of data subsets '
        'to use for validation (train, valid, valid1, test, test1)')
    options.add_optimization_args(parser)
    options.add_checkpoint_args(parser)
    options.add_model_args(parser)

    args = utils.parse_args_and_arch(parser)
    print(args)

    if args.no_progress_bar:
        progress_bar.enabled = False
        progress_bar.print_interval = args.log_interval

    # exist_ok avoids the check-then-create race when several processes
    # start with the same save directory
    os.makedirs(args.save_dir, exist_ok=True)
    torch.manual_seed(args.seed)

    # Load dataset
    dataset = data.load_with_check(args.data, ['train', 'valid'],
                                   args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    for split in ['train', 'valid']:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    num_gpus = torch.cuda.device_count()

    print('| using {} GPUs (with max tokens per GPU = {})'.format(
        num_gpus, args.max_tokens))

    # Build model and criterion
    model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
    criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))

    # Start multiprocessing
    trainer = MultiprocessingTrainer(args, model, criterion)

    # Load the latest checkpoint if one is available
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    extra_state = trainer.load_checkpoint(checkpoint_path)
    if extra_state is not None:
        epoch = extra_state['epoch']
        batch_offset = extra_state['batch_offset']
        print('| loaded checkpoint {} (epoch {})'.format(
            checkpoint_path, epoch))
        if batch_offset == 0:
            # a zero offset is what the end-of-epoch save below records, so
            # continue with the following epoch
            epoch += 1
    else:
        epoch, batch_offset = 1, 0

    # Train until the learning rate gets too small
    val_loss = None
    # a falsy max_epoch (0 / None) means "no epoch limit"
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch <= max_epoch:
        # train for one epoch
        train(args, epoch, batch_offset, trainer, dataset, num_gpus)

        # evaluate on validate set
        for k, subset in enumerate(args.valid_subset.split(',')):
            val_loss = validate(args, epoch, trainer, dataset, subset,
                                num_gpus)
            if k == 0:
                if not args.no_save:
                    # save checkpoint
                    save_checkpoint(trainer, args, epoch, 0, val_loss)
                # only use first validation loss to update the learning schedule
                lr = trainer.lr_step(val_loss, epoch)

        epoch += 1
        # only the first (resumed) epoch starts part-way through
        batch_offset = 0
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))

    # Stop multiprocessing
    trainer.stop()
Ejemplo n.º 6
0
def main():
    """Parse options, build the model and run the train/validate loop.

    The function registers the dataset, optimization, checkpoint and model
    options, loads the ``train`` and ``valid`` splits, builds the model and
    criterion, starts a ``MultiprocessingTrainer`` and then alternates
    training and validation epochs until the learning rate drops to
    ``args.min_lr`` or ``args.max_epoch`` is exceeded. Per-epoch losses are
    logged to ``train_losses.csv`` and ``valid_losses.csv`` in
    ``args.save_dir``.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    parser = options.get_parser('Trainer')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('--max-tokens',
                              default=6000,
                              type=int,
                              metavar='N',
                              help='maximum number of tokens in a batch')
    dataset_args.add_argument('--max-sentences',
                              type=int,
                              metavar='N',
                              help='maximum number of sentences in a batch')
    dataset_args.add_argument(
        '--train-subset',
        default='train',
        metavar='SPLIT',
        choices=['train', 'valid', 'test'],
        help='data subset to use for training (train, valid, test)')
    dataset_args.add_argument(
        '--valid-subset',
        default='valid',
        metavar='SPLIT',
        help='comma separated list of data subsets '
        ' to use for validation (train, valid, valid1,test, test1)')
    dataset_args.add_argument(
        '--max-sentences-valid',
        type=int,
        metavar='N',
        help='maximum number of sentences in a validation batch')
    options.add_optimization_args(parser)
    options.add_checkpoint_args(parser)
    options.add_model_args(parser)

    args = utils.parse_args_and_arch(parser)

    if args.no_progress_bar and args.log_format is None:
        args.log_format = 'simple'

    # validation batches fall back to the training sentence limit
    if args.max_sentences_valid is None:
        args.max_sentences_valid = args.max_sentences

    # exist_ok avoids the check-then-create race when several jobs start
    # with the same save directory
    os.makedirs(args.save_dir, exist_ok=True)
    torch.manual_seed(args.seed)

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.source_lang,
                                    args.target_lang)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits,
                                             args.source_lang,
                                             args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    args.num_gpus = torch.cuda.device_count()

    print(args)
    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    print(
        '| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'
        .format(args.num_gpus, args.max_tokens, args.max_sentences))

    # Build model and criterion
    model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
    criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.data.numel() for p in model.parameters())))

    # The max number of positions can be different for train and valid
    # e.g., RNNs may support more positions at test time than seen in training
    max_positions_train = (min(args.max_source_positions,
                               model.max_encoder_positions()),
                           min(args.max_target_positions,
                               model.max_decoder_positions()))
    max_positions_valid = (model.max_encoder_positions(),
                           model.max_decoder_positions())

    # Start multiprocessing
    trainer = MultiprocessingTrainer(args, model, criterion)

    # Create the files that collect the per-epoch losses.
    # NOTE(review): both files are truncated here, before the checkpoint is
    # loaded, so resuming a run discards the earlier rows - confirm that
    # this is intended.
    traincsv_path = os.path.join(args.save_dir, 'train_losses.csv')
    validcsv_path = os.path.join(args.save_dir, 'valid_losses.csv')
    for path in (traincsv_path, validcsv_path):
        # newline='' lets the csv writer control line endings itself; the
        # with-block closes the file, so no explicit close() is needed
        with open(path, 'w', newline='') as csvfile:
            csvwriter = csv.writer(csvfile, delimiter=',')
            csvwriter.writerow(['Epoch', 'Perplexity', 'Loss'])

    # Load the latest checkpoint if one is available
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    extra_state = trainer.load_checkpoint(checkpoint_path)
    if extra_state is not None:
        epoch = extra_state['epoch']
        batch_offset = extra_state['batch_offset']
        print('| loaded checkpoint {} (epoch {})'.format(
            checkpoint_path, epoch))
        if batch_offset == 0:
            # the stored epoch was completed, so continue with the next one
            epoch += 1
    else:
        epoch, batch_offset = 1, 0

    # Train until the learning rate gets too small
    val_loss = None
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch <= max_epoch:
        # train for one epoch
        train(args, epoch, batch_offset, trainer, dataset, max_positions_train,
              traincsv_path)

        # evaluate on validate set
        for k, subset in enumerate(args.valid_subset.split(',')):
            val_loss = validate(args, epoch, trainer, dataset,
                                max_positions_valid, subset, validcsv_path)
            if k == 0:
                if not args.no_save:
                    # save checkpoint
                    save_checkpoint(trainer, args, epoch, 0, val_loss)
                # only use first validation loss to update the learning schedule
                lr = trainer.lr_step(val_loss, epoch)

        epoch += 1
        # only the first (resumed) epoch starts part-way through the data
        batch_offset = 0
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))

    # Stop multiprocessing
    trainer.stop()
Ejemplo n.º 7
0
class FairseqAgent(Agent):
    """Agent which takes an input sequence and produces an output sequence.

    For more information, see Convolutional Sequence to Sequence Learning
     `(Gehring et al. 2017) <https://arxiv.org/abs/1705.03122>`_.
    """

    @staticmethod
    def add_cmdline_args(argparser):
        """Add command-line arguments specifically for this agent.

        Registers the dictionary agent's options, a ``Fairseq Arguments``
        group (``--truncate``, ``--max-positions``, ``--seed``) and the
        optimization, generation and model option sets from ``options``.
        """
        DictionaryAgent.add_cmdline_args(argparser)
        agent = argparser.add_argument_group('Fairseq Arguments')
        agent.add_argument(
            '-tr', '--truncate',
            type=int, default=-1,
            help='truncate input & output lengths to speed up training (may '
                 'reduce accuracy). This fixes all input and output to have a '
                 'maximum length. This reduces the total amount of padding in '
                 'the batches.')
        agent.add_argument(
            '--max-positions',
            default=1024,
            type=int,
            metavar='N',
            help='max number of tokens in the sequence')
        agent.add_argument(
            '--seed',
            default=1,
            type=int,
            metavar='N',
            help='pseudo random number generator seed')
        options.add_optimization_args(argparser)
        options.add_generation_args(argparser)
        options.add_model_args(argparser)

    def __init__(self, opt, shared=None):
        """Build the dictionary, model, criterion and trainer.

        When ``shared`` is truthy nothing is built; only ``reset()`` runs.
        Otherwise, if ``opt['model_file']`` names an existing file, the
        stored options and weights are loaded from it, the stored model
        options override the given ones, and the weights are restored once
        the trainer exists.
        """
        # initialize defaults first
        super().__init__(opt, shared)
        if not shared:
            # this is not a shared instance of this class, so do full
            # initialization. if shared is set, only set up shared members.
            saved_state = None
            if opt.get('model_file') and os.path.isfile(opt['model_file']):
                # load model parameters if available
                print('Loading existing model params from ' +
                      opt['model_file'])
                new_opt, saved_state = self.load(opt['model_file'])
                # override options with stored ones
                opt = self._override_opt(new_opt)

            self.args = OptWrapper(opt)
            self.fairseq_dict = _make_fairseq_dict(DictionaryAgent(opt))
            self.id = 'Fairseq'
            # a non-positive --truncate means "no length limit"
            self.truncate = opt['truncate'] if opt['truncate'] > 0 else None

            self.EOS = self.fairseq_dict[self.fairseq_dict.eos()]
            self.EOS_TENSOR = (torch.LongTensor(1, 1)
                               .fill_(self.fairseq_dict.eos()))
            self.NULL_IDX = self.fairseq_dict.pad()

            # NOTE(review): eval() runs the layer/attention option strings
            # as Python code, so they must only come from a trusted command
            # line or checkpoint.
            encoder = fconv.Encoder(
                self.fairseq_dict,
                embed_dim=self.args.encoder_embed_dim,
                convolutions=eval(self.args.encoder_layers),
                dropout=self.args.dropout,
                max_positions=self.args.max_positions)
            decoder = fconv.Decoder(
                self.fairseq_dict,
                embed_dim=self.args.decoder_embed_dim,
                convolutions=eval(self.args.decoder_layers),
                out_embed_dim=self.args.decoder_out_embed_dim,
                attention=eval(self.args.decoder_attention),
                dropout=self.args.dropout,
                max_positions=self.args.max_positions)
            self.model = fconv.FConvModel(encoder, decoder)

            # from fairseq's build_criterion()
            if self.args.label_smoothing > 0:
                self.criterion = criterions.LabelSmoothedCrossEntropyCriterion(
                    self.args.label_smoothing, self.NULL_IDX)
            else:
                self.criterion = criterions.CrossEntropyCriterion(
                    self.NULL_IDX)

            self.trainer = MultiprocessingTrainer(
                self.args, self.model, self.criterion)
            if saved_state is not None:
                self.set_states(saved_state)
        self.reset()

    def _override_opt(self, new_opt):
        """Set overridable opts from loaded opt file.

        Print out each added key and each overridden key.
        Only override args specific to the model.

        Args:
            new_opt: option dict read from a saved model file.

        Returns:
            ``self.opt`` after the model-specific entries were copied in.
        """
        # Option dicts are keyed by argparse destination names, which use
        # underscores ('--encoder-embed-dim' -> 'encoder_embed_dim'). The
        # hyphenated spellings used previously never matched such keys, so
        # only 'arch' was ever restored from a checkpoint.
        model_args = {
            'arch',
            'encoder_embed_dim',
            'encoder_layers',
            'decoder_embed_dim',
            'decoder_layers',
            'decoder_out_embed_dim',
            'decoder_attention',
        }

        for k, v in new_opt.items():
            # normalise so that a hyphenated key is still accepted
            if str(k).replace('-', '_') not in model_args:
                # skip non-model args
                continue
            if k not in self.opt:
                print('Adding new option [ {k}: {v} ]'.format(k=k, v=v))
            elif self.opt[k] != v:
                print('Overriding option [ {k}: {old} => {v}]'.format(
                    k=k, old=self.opt[k], v=v))
            self.opt[k] = v
        return self.opt

    def reset(self):
        """Reset observation and episode_done."""
        self.observation = None
        self.episode_done = True

    def observe(self, observation):
        """Store a copy of ``observation``, prefixed with the episode so far.

        While the previous observation did not end its episode, its text is
        prepended (separated by a newline) to the new observation's text.

        Returns:
            The stored (shallow-copied) observation.
        """
        # shallow copy observation (deep copy can be expensive)
        observation = observation.copy()
        if not self.episode_done:
            # if the last example wasn't the end of an episode, then we need to
            # recall what was said in that example
            prev_dialogue = self.observation['text']
            observation['text'] = prev_dialogue + '\n' + observation['text']
        self.observation = observation
        self.episode_done = observation['episode_done']
        return observation

    def act(self):
        """Reply to the last observation via a batch of one."""
        # call batch_act with this batch of one
        return self.batch_act([self.observation])[0]

    def batch_act(self, observations):
        """Generate replies for, or train on, a batch of observations.

        The batch is split into one sub-batch per trainer replica. If no
        sub-batch carries targets, predictions are written to the ``'text'``
        field of the matching replies; otherwise the model is trained and
        the losses scaled by the batch size (plus a ``'perplexity'`` entry
        derived from ``'loss'``) are stored under ``'metrics'`` of the
        first reply.

        Returns:
            One reply dict per observation, each holding at least ``'id'``.
        """
        bsz = len(observations)
        # initialize a table of replies with this agent's id
        batch_reply = [{'id': self.getID()} for _ in range(bsz)]

        # convert the observations into batches of inputs and targets
        # valid_inds tells us the indices of all valid examples
        # e.g. for input [{}, {'text': 'hello'}, {}, {}], valid_inds is [1]
        # since the other three elements had no 'text' field

        # also, split observations into sub-batches based on number of gpus
        obs_split = np.array_split(observations, self.trainer.num_replicas)
        samples = []
        offsets = []
        start = 0
        for obs in obs_split:
            sample = self.batchify(obs)
            if sample[0] is not None:
                samples.append(sample)
                # valid_inds are relative to the sub-batch, so remember
                # where this sub-batch starts in the full batch
                offsets.append(start)
            # advance by the full sub-batch length: text-less observations
            # and dropped sub-batches still occupy reply slots
            start += len(obs)

        any_valid = any(len(s[0]) > 0 for s in samples)
        if not any_valid:
            # no valid examples, just return the empty responses we set up
            return batch_reply

        # produce predictions if testing; otherwise, train.
        # sub-batches without targets cannot be trained on, so they are
        # left out instead of being passed to _train() with ys=None
        train_samples = [s for s in samples if s[1] is not None]
        if not train_samples:
            for s, offset in zip(samples, offsets):
                xs, valid_inds = s[0], s[2]
                predictions = self._generate(self.args, xs)
                for i, prediction in enumerate(predictions):
                    # map the predictions back to non-empty examples in the
                    # batch
                    batch_reply[valid_inds[i] + offset]['text'] = prediction
                    if i == 0:
                        print('prediction:', prediction)
        else:
            loss = self._train(train_samples)

            metrics = {}
            for k, v in loss.items():
                metrics[k] = v * bsz
                if k == 'loss':
                    metrics['perplexity'] = (2 ** v) * bsz
            batch_reply[0]['metrics'] = metrics

        return batch_reply

    def parse(self, string):
        """Return the dictionary index of each space-separated word."""
        return [self.fairseq_dict.index(word) for word in string.split(' ')]

    def batchify(self, observations):
        """Convert a list of observations into input & target tensors.

        Returns:
            ``(xs, ys, valid_inds)``: ``xs`` holds the left-padded input
            indices, ``ys`` the right-padded target indices (``None`` when
            the first valid example has no ``'labels'``), and
            ``valid_inds`` the positions of the observations that have a
            ``'text'`` field. ``(None, None, None)`` is returned when no
            observation has ``'text'``.
        """
        # valid examples together with their positions in the batch
        valid = [(i, ex) for i, ex in enumerate(observations) if 'text' in ex]
        if not valid:
            return None, None, None
        valid_inds = [i for i, _ in valid]
        exs = [ex for _, ex in valid]

        pad = self.fairseq_dict.pad()

        # tokenize the text; a bounded deque filled from the left keeps only
        # the last `truncate` tokens of each input
        parsed_x = [deque(self.parse(ex['text']), maxlen=self.truncate)
                    for ex in exs]
        max_x_len = max(len(x) for x in parsed_x)
        for x in parsed_x:
            # left pad up to the longest input
            x.extendleft([pad] * (max_x_len - len(x)))
        # hand plain lists to the tensor constructor
        xs = torch.LongTensor([list(x) for x in parsed_x])

        # set up the target tensors
        ys = None
        if 'labels' in exs[0]:
            # randomly select one of the labels to update on, if multiple
            labels = [random.choice(ex.get('labels', [''])) for ex in exs]
            eos = self.fairseq_dict.eos()
            parsed_y = []
            for label in labels:
                dq = deque(maxlen=self.truncate)
                # pushing the reversed tokens on the left keeps the first
                # `truncate` tokens of the label
                dq.extendleft(reversed(self.parse(label)))
                # append EOS to each label; if the deque is already full
                # this evicts the label's first token
                dq.append(eos)
                parsed_y.append(dq)
            max_y_len = max(len(y) for y in parsed_y)
            for y in parsed_y:
                # right pad up to the longest target
                y.extend([pad] * (max_y_len - len(y)))
            ys = torch.LongTensor([list(y) for y in parsed_y])
        return xs, ys, valid_inds

    def _positions_for_tokens(self, tokens):
        """Return a position tensor of the same size as ``tokens``.

        Every entry starts at the pad index, non-pad entries are bumped by
        one, and a running sum along dimension 1 then numbers the tokens.
        """
        size = tokens.size()
        not_pad = tokens.ne(self.fairseq_dict.pad()).long()
        new_pos = tokens.new(size).fill_(self.fairseq_dict.pad())
        new_pos += not_pad
        # NOTE(review): the "- 1" only leaves padded slots at the pad index
        # when pad() == 1 - confirm that the dictionary guarantees this
        for i in range(1, size[1]):
            new_pos[:, i] += new_pos[:, i-1] - 1
        return new_pos

    def _right_shifted_ys(self, ys):
        """Return ``ys`` shifted right by one column, starting with EOS."""
        result = torch.LongTensor(ys.size())
        result[:, 0] = self.fairseq_dict.index(self.EOS)
        result[:, 1:] = ys[:, :-1]
        return result

    def _generate(self, opt, src_tokens):
        """Translate ``src_tokens`` and return one output string per row.

        The ``SequenceGenerator`` is created on first use and cached on the
        instance as ``self.translator``.
        """
        if not hasattr(self, 'translator'):
            self.translator = SequenceGenerator(
                [self.trainer.get_model()],
                beam_size=opt.beam,
                stop_early=(not opt.no_early_stop),
                normalize_scores=(not opt.unnormalized),
                len_penalty=opt.lenpen)
            self.translator.cuda()
        # `async` is a reserved word since Python 3.7, so the former
        # `.cuda(async=True)` no longer parses. An asynchronous copy only
        # applies to pinned host memory, which batchify() does not
        # allocate, so a plain copy is used instead.
        tokens = src_tokens.cuda()
        token_pos = Variable(self._positions_for_tokens(tokens).cuda())
        translations = self.translator.generate(Variable(tokens), token_pos)
        # keep the first hypothesis of every input (presumably the
        # best-scoring one - confirm against SequenceGenerator)
        results = [t[0] for t in translations]
        # the last token of each hypothesis is dropped (presumably EOS)
        return [
            ' '.join(self.fairseq_dict[idx] for idx in result['tokens'][:-1])
            for result in results
        ]

    def _train(self, samples):
        """Update the model using the targets.

        Args:
            samples: ``(xs, ys, valid_inds)`` tuples as returned by
                ``batchify``; every ``ys`` must be a tensor.

        Returns:
            Whatever ``self.trainer.train_step`` returns.
        """
        # build a new list instead of rewriting the caller's list in place
        prepared = []
        for raw in samples:
            # add extra info to samples
            sample = {
                'src_tokens': raw[0],
                'input_tokens': self._right_shifted_ys(raw[1]),
                'target': raw[1],
                'id': None
            }
            # NOTE(review): every target row has the padded width, so this
            # count includes padding positions - confirm this is intended
            sample['ntokens'] = sum(len(t) for t in sample['target'])
            sample['src_positions'] = self._positions_for_tokens(
                sample['src_tokens'])
            sample['input_positions'] = self._positions_for_tokens(
                sample['input_tokens'])
            prepared.append(sample)
        return self.trainer.train_step(prepared)

    def save(self, path=None):
        """Write the model weights and options to ``path``.

        ``path`` defaults to ``self.opt['model_file']``. Nothing is written
        when no path is available or the trainer was never created.
        """
        path = self.opt.get('model_file', None) if path is None else path
        if path and hasattr(self, 'trainer'):
            model = {}
            model['state_dict'] = self.trainer.get_model().state_dict()
            model['opt'] = self.opt
            with open(path, 'wb') as write:
                torch.save(model, write)

    def shutdown(self):
        """Save the state of the model when shutdown."""
        path = self.opt.get('model_file', None)
        if path is not None:
            self.save(path + '.shutdown_state')
        super().shutdown()

    def load(self, path):
        """Return opt and model states."""
        with open(path, 'rb') as read:
            model = torch.load(read)
        return model['opt'], model['state_dict']

    def set_states(self, state_dict):
        """Set the state dict of the model from saved states."""
        self.trainer.get_model().load_state_dict(state_dict)
Ejemplo n.º 8
0
def main():
    """Parse options, build the model and run the train/validate loop.

    The function registers the dataset, optimization, checkpoint and model
    options, loads the ``train`` and ``valid`` splits, builds the model and
    criterion, starts a ``MultiprocessingTrainer`` and then alternates
    training and validation epochs until the learning rate drops to
    ``args.min_lr`` or ``args.max_epoch`` is exceeded.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    parser = options.get_parser('Trainer')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('--max-tokens', default=6000, type=int, metavar='N',
                              help='maximum number of tokens in a batch')
    dataset_args.add_argument('--max-sentences', type=int, metavar='N',
                              help='maximum number of sentences in a batch')
    dataset_args.add_argument('--train-subset', default='train', metavar='SPLIT',
                              choices=['train', 'valid', 'test'],
                              help='data subset to use for training (train, valid, test)')
    dataset_args.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                              help='comma separated list of data subsets '
                                   ' to use for validation (train, valid, valid1,test, test1)')
    options.add_optimization_args(parser)
    options.add_checkpoint_args(parser)
    options.add_model_args(parser)

    args = utils.parse_args_and_arch(parser)

    if args.no_progress_bar and args.log_format is None:
        args.log_format = 'simple'

    # exist_ok avoids the check-then-create race when several jobs start
    # with the same save directory
    os.makedirs(args.save_dir, exist_ok=True)
    torch.manual_seed(args.seed)

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.source_lang, args.target_lang)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    print(args)
    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    num_gpus = torch.cuda.device_count()

    print('| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
        num_gpus, args.max_tokens, args.max_sentences))

    # Build model and criterion
    model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
    criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))

    # The max number of positions can be different for train and valid
    # e.g., RNNs may support more positions at test time than seen in training
    max_positions_train = (args.max_source_positions, args.max_target_positions)
    max_positions_valid = (
        min(args.max_source_positions, model.max_encoder_positions()),
        min(args.max_target_positions, model.max_decoder_positions())
    )

    # Start multiprocessing
    trainer = MultiprocessingTrainer(args, model, criterion)

    # Load the latest checkpoint if one is available
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    extra_state = trainer.load_checkpoint(checkpoint_path)
    if extra_state is not None:
        epoch = extra_state['epoch']
        batch_offset = extra_state['batch_offset']
        print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
        if batch_offset == 0:
            # the stored epoch was completed, so continue with the next one
            epoch += 1
    else:
        epoch, batch_offset = 1, 0

    # Train until the learning rate gets too small
    val_loss = None
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch <= max_epoch:
        # train for one epoch
        train(args, epoch, batch_offset, trainer, dataset, max_positions_train, num_gpus)

        # evaluate on validate set
        for k, subset in enumerate(args.valid_subset.split(',')):
            val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset, num_gpus)
            if k == 0:
                if not args.no_save:
                    # save checkpoint
                    save_checkpoint(trainer, args, epoch, 0, val_loss)
                # only use first validation loss to update the learning schedule
                lr = trainer.lr_step(val_loss, epoch)

        epoch += 1
        # only the first (resumed) epoch starts part-way through the data
        batch_offset = 0
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))

    # Stop multiprocessing
    trainer.stop()
Ejemplo n.º 9
0
class FairseqAgent(Agent):
    """Agent which takes an input sequence and produces an output sequence.

    For more information, see Convolutional Sequence to Sequence Learning
     `(Gehring et al. 2017) <https://arxiv.org/abs/1705.03122>`_.
    """

    @staticmethod
    def add_cmdline_args(argparser):
        """Register the dictionary and Fairseq options on ``argparser``.

        The dictionary agent's options are added first, followed by a
        ``Fairseq Arguments`` group holding the optimization, generation
        and model options listed in the table below.
        """
        DictionaryAgent.add_cmdline_args(argparser)
        group = argparser.add_argument_group('Fairseq Arguments')

        # (option strings, keyword arguments), in registration order
        specs = (
            (('--max-positions',),
             dict(default=1024, type=int, metavar='N',
                  help='max number of tokens in the sequence')),
            (('--seed',),
             dict(default=1, type=int, metavar='N',
                  help='pseudo random number generator seed')),
            (('--lr', '--learning-rate'),
             dict(default=0.25, type=float, metavar='LR',
                  help='initial learning rate')),
            (('--momentum',),
             dict(default=0.99, type=float, metavar='M',
                  help='momentum factor')),
            (('--weight-decay', '--wd'),
             dict(default=0.0, type=float, metavar='WD',
                  help='weight decay')),
            (('--force-anneal', '--fa'),
             dict(default=0, type=int, metavar='N',
                  help='force annealing at specified epoch')),
            (('--beam',),
             dict(default=5, type=int, metavar='N', help='beam size')),
            (('--no-early-stop',),
             dict(action='store_true',
                  help=('continue searching even after finalizing k=beam '
                        'hypotheses; this is more correct, but increases '
                        'generation time by 50%%'))),
            (('--unnormalized',),
             dict(action='store_true',
                  help='compare unnormalized hypothesis scores')),
            (('--lenpen',),
             dict(default=1, type=float,
                  help='length penalty: <1.0 favors shorter, '
                       '>1.0 favors longer sentences')),
            (('--clip-norm',),
             dict(default=25, type=float, metavar='NORM',
                  help='clip threshold of gradients')),
            (('--arch', '-a'),
             dict(default='fconv', metavar='ARCH',
                  choices=models.arch_model_map.keys(),
                  help='model architecture ({})'.format(
                      ', '.join(models.arch_model_map.keys())))),
            (('--encoder-embed-dim',),
             dict(type=int, metavar='N',
                  help='encoder embedding dimension')),
            (('--encoder-layers',),
             dict(type=str, metavar='EXPR',
                  help='encoder layers [(dim, kernel_size), ...]')),
            (('--decoder-embed-dim',),
             dict(type=int, metavar='N',
                  help='decoder embedding dimension')),
            (('--decoder-layers',),
             dict(type=str, metavar='EXPR',
                  help='decoder layers [(dim, kernel_size), ...]')),
            (('--decoder-out-embed-dim',),
             dict(type=int, metavar='N',
                  help='decoder output embedding dimension')),
            (('--decoder-attention',),
             dict(type=str, metavar='EXPR',
                  help='decoder attention [True, ...]')),
            # These arguments have default values independent of the model:
            (('--dropout',),
             dict(default=0.1, type=float, metavar='D',
                  help='dropout probability')),
            (('--label-smoothing',),
             dict(default=0, type=float, metavar='D',
                  help='epsilon for label smoothing, '
                       '0 means no label smoothing')),
        )
        for flags, kwargs in specs:
            group.add_argument(*flags, **kwargs)

    def __init__(self, opt, shared=None):
        """Build the dictionary, model, criterion and trainer.

        When ``shared`` is truthy nothing is built; only ``reset()`` runs.
        Otherwise, if ``opt['model_file']`` names an existing file, the
        stored options and weights are loaded from it, the stored model
        options override the given ones, and the weights are restored once
        the trainer exists.
        """
        # the base class sets up its defaults before anything else
        super().__init__(opt, shared)
        if shared:
            # shared instances skip model construction entirely
            self.reset()
            return

        checkpoint = opt.get('model_file')
        if checkpoint and os.path.isfile(opt['model_file']):
            # reuse the parameters stored by an earlier run
            print('Loading existing model params from ' +
                  opt['model_file'])
            new_opt, self.saved_state = self.load(opt['model_file'])
            # stored model options take precedence over the given ones
            opt = self._override_opt(new_opt)

        self.args = OptWrapper(opt)
        self.fairseq_dict = _make_fairseq_dict(DictionaryAgent(opt))
        self.id = 'Fairseq'

        self.EOS = self.fairseq_dict[self.fairseq_dict.eos()]
        self.NULL_IDX = self.fairseq_dict.pad()

        # the layer and attention options are strings holding Python
        # expressions, hence the eval() calls
        enc = fconv.Encoder(
            len(self.fairseq_dict),
            embed_dim=self.args.encoder_embed_dim,
            convolutions=eval(self.args.encoder_layers),
            dropout=self.args.dropout,
            padding_idx=self.NULL_IDX,
            max_positions=self.args.max_positions)
        dec = fconv.Decoder(
            len(self.fairseq_dict),
            embed_dim=self.args.decoder_embed_dim,
            convolutions=eval(self.args.decoder_layers),
            out_embed_dim=self.args.decoder_out_embed_dim,
            attention=eval(self.args.decoder_attention),
            dropout=self.args.dropout,
            padding_idx=self.NULL_IDX,
            max_positions=self.args.max_positions)
        self.model = fconv.FConvModel(enc, dec, self.NULL_IDX)

        # mirrors fairseq's build_criterion(): smoothing picks the criterion
        if self.args.label_smoothing > 0:
            self.criterion = criterions.LabelSmoothedCrossEntropyCriterion(
                self.args.label_smoothing, self.NULL_IDX)
        else:
            self.criterion = criterions.CrossEntropyCriterion(
                self.NULL_IDX)

        self.trainer = MultiprocessingTrainer(self.args, self.model)
        # saved_state only exists when a model file was loaded above
        if hasattr(self, 'saved_state'):
            self.set_states(self.saved_state)

        self.reset()

    def _override_opt(self, new_opt):
        """Set overridable opts from loaded opt file.

        Print out each added key and each overriden key.
        Only override args specific to the model.
        """
        model_args = {
            'arch',
            'encoder-embed-dim',
            'encoder-layers',
            'decoder-embed-dim',
            'decoder-layers',
            'decoder-out-embed-dim',
            'decoder-attention',
        }

        for k, v in new_opt.items():
            if k not in model_args:
                # skip non-model args
                continue
            if k not in self.opt:
                print('Adding new option [ {k}: {v} ]'.format(k=k, v=v))
            elif self.opt[k] != v:
                print('Overriding option [ {k}: {old} => {v}]'.format(
                    k=k, old=self.opt[k], v=v))
            self.opt[k] = v
        return self.opt

    def reset(self):
        """Reset observation and episode_done."""
        self.observation = None
        self.episode_done = True

    def observe(self, observation):
        # shallow copy observation (deep copy can be expensive)
        observation = observation.copy()
        if not self.episode_done:
            # if the last example wasn't the end of an episode, then we need to
            # recall what was said in that example
            prev_dialogue = self.observation['text']
            observation['text'] = prev_dialogue + '\n' + observation['text']
        self.observation = observation
        self.episode_done = observation['episode_done']
        return observation

    def act(self):
        # call batch_act with this batch of one
        return self.batch_act([self.observation])[0]

    def batch_act(self, observations):
        """Train on, or generate replies for, a batch of observations.

        Every observation gets a reply dict holding this agent's id. When
        the batch carries targets the model takes a training step and the
        replies stay text-free; otherwise generated text is written into
        the replies of the observations that were valid.
        """
        # one reply per observation, pre-filled with this agent's id
        replies = [{'id': self.getID()} for _ in observations]

        # valid_inds holds the positions of the observations that produced
        # a row in xs, e.g. for [{}, {'text': 'hello'}, {}, {}] it is [1]
        xs, ys, valid_inds = self.batchify(observations)

        if len(xs) == 0:
            # nothing usable in this batch: hand back the bare replies
            return replies

        if ys is not None:
            # targets are present, so this is a training batch
            self._train(xs, ys)
            return replies

        # no targets: generate, then route each prediction back to the
        # observation it was produced from
        for row, prediction in enumerate(self._generate(self.args, xs)):
            replies[valid_inds[row]]['text'] = prediction
        return replies

    def parse(self, string):
        """Map each token of ``string`` to its dictionary index.

        Tokens are obtained by splitting on single spaces, so consecutive
        spaces yield empty-string tokens that are looked up as well.
        """
        lookup = self.fairseq_dict.index
        return list(map(lookup, string.split(' ')))

    def batchify(self, observations):
        """Convert a list of observations into input & target tensors.

        Only observations containing a 'text' field are encoded. Inputs
        are packed to the right side of the tensor (padding on the left),
        targets to the left side (padding on the right).

        Args:
            observations: list of observation dicts.

        Returns:
            A tuple ``(xs, ys, valid_inds)``:
            xs is a CUDA LongTensor of token indices with one row per
            valid example, or an empty LongTensor when no observation has
            text; ys is a CUDA LongTensor of target indices (one randomly
            chosen label per example, with EOS appended), or None when the
            first valid example has no 'labels'; valid_inds lists the
            positions in ``observations`` of the valid examples.
        """
        # the indices of the valid (non-empty) examples
        valid_inds = [i for i, ex in enumerate(observations) if 'text' in ex]
        if not valid_inds:
            # nothing to encode: return an empty batch rather than calling
            # max() on an empty sequence, so the caller can bail out early
            return torch.LongTensor(), None, valid_inds

        exs = [observations[i] for i in valid_inds]
        batchsize = len(exs)
        pad = self.fairseq_dict.pad()

        # tokenize the text
        parsed = [self.parse(ex['text']) for ex in exs]
        max_x_len = max(len(x) for x in parsed)
        xs = torch.LongTensor(batchsize, max_x_len).fill_(pad)
        # pack the data to the right side of the tensor for this model
        for i, x in enumerate(parsed):
            if x:
                xs[i, max_x_len - len(x):] = torch.LongTensor(x)
        # `non_blocking` is the current name of the former `async` argument,
        # which is a reserved word from Python 3.7 on
        xs = xs.cuda(non_blocking=True)

        # set up the target tensors
        ys = None
        if 'labels' in exs[0]:
            # randomly select one of the labels to update on, if multiple
            # append EOS to each label
            labels = [
                random.choice(ex['labels']) + ' ' + self.EOS for ex in exs
            ]
            parsed = [self.parse(y) for y in labels]
            max_y_len = max(len(y) for y in parsed)
            ys = torch.LongTensor(batchsize, max_y_len).fill_(pad)
            for i, y in enumerate(parsed):
                if y:
                    ys[i, :len(y)] = torch.LongTensor(y)
            ys = ys.cuda(non_blocking=True)
        return xs, ys, valid_inds

    def _positions_for_tokens(self, tokens):
        """Build the position-index tensor matching a 2-D token tensor.

        Within each row, the k-th non-pad token (counting from 1) gets
        position ``pad + k``; pad tokens get the pad index itself.

        Args:
            tokens: 2-D LongTensor of token indices (rows are examples).

        Returns:
            A CUDA LongTensor of the same shape as ``tokens``.
        """
        pad = self.fairseq_dict.pad()
        # 1 where there is a real token, 0 where there is padding
        nonpad = tokens.ne(pad).long()
        # the row-wise running count of real tokens gives 1, 2, 3, ... at the
        # real slots; the mask zeroes the padded slots, and adding `pad`
        # shifts real slots to pad + k while leaving padded slots at pad
        positions = nonpad.cumsum(1) * nonpad + pad
        # `non_blocking` is the current name of the former `async` argument,
        # which is a reserved word from Python 3.7 on
        return positions.cuda(non_blocking=True)

    def _right_shifted_ys(self, ys):
        result = torch.LongTensor(ys.size())
        result[:, 0] = self.fairseq_dict.index(self.EOS)
        result[:, 1:] = ys[:, :-1]
        return result

    def _generate(self, opt, src_tokens):
        """Generate one output string per row of ``src_tokens``.

        A SequenceGenerator is built around the trainer's model using the
        beam, early-stop, score-normalization and length-penalty settings
        from ``opt``. For each source row the first returned hypothesis is
        kept, its final token is dropped, and the remaining indices are
        mapped through the dictionary and joined with spaces.
        """
        generator = SequenceGenerator(
            [self.trainer.get_model()],
            self.fairseq_dict,
            beam_size=opt.beam,
            stop_early=(not opt.no_early_stop),
            normalize_scores=(not opt.unnormalized),
            len_penalty=opt.lenpen)
        generator.cuda()

        source = Variable(src_tokens)
        source_positions = Variable(self._positions_for_tokens(src_tokens))
        hypotheses = generator.generate(source, source_positions)

        # keep only the first hypothesis returned for every example
        best = [candidates[0] for candidates in hypotheses]
        return [
            ' '.join(self.fairseq_dict[idx] for idx in hypo['tokens'][:-1])
            for hypo in best
        ]

    def _train(self, xs, ys=None):
        """Produce a prediction from our model. Update the model using the
        targets if available.
        """
        if ys is not None:
            sample = {
                'src_tokens': xs,
                'input_tokens': self._right_shifted_ys(ys),
                'target': ys,
                'id': None
            }
            sample['ntokens'] = sum(len(t) for t in sample['target'])
            sample['src_positions'] = self._positions_for_tokens(
                sample['src_tokens'])
            sample['input_positions'] = self._positions_for_tokens(
                sample['input_tokens'])
            self.trainer.train_step([sample], self.criterion)

    def save(self, path=None):
        """Write the model weights and options to ``path``.

        When ``path`` is None the 'model_file' option is used instead.
        Nothing is written if the resulting path is empty or the agent has
        no ``trainer`` attribute.
        """
        if path is None:
            path = self.opt.get('model_file', None)
        if not path or not hasattr(self, 'trainer'):
            return

        checkpoint = {
            'state_dict': self.trainer.get_model().state_dict(),
            'opt': self.opt,
        }
        with open(path, 'wb') as handle:
            torch.save(checkpoint, handle)

    def shutdown(self):
        """Save the model next to 'model_file', then shut down the parent.

        The state is written to the 'model_file' option with
        '.shutdown_state' appended; when that option is absent or None no
        save is attempted.
        """
        model_file = self.opt.get('model_file', None)
        if model_file is not None:
            shutdown_path = model_file + '.shutdown_state'
            self.save(shutdown_path)
        super().shutdown()

    def load(self, path):
        """Read a saved checkpoint and return ``(opt, state_dict)``."""
        # NOTE: torch.load unpickles the file contents, so only load
        # checkpoints that come from a trusted source.
        with open(path, 'rb') as handle:
            checkpoint = torch.load(handle)
        saved_opt = checkpoint['opt']
        saved_state = checkpoint['state_dict']
        return saved_opt, saved_state

    def set_states(self, state_dict):
        """Load ``state_dict`` into the model held by the trainer."""
        model = self.trainer.get_model()
        model.load_state_dict(state_dict)