Example #1
    def _convert(self, model, global_input):
        if self.verbose:
            logger.info('Converting the model...')

        if not isinstance(global_input, tuple):
            global_input = (global_input, )
        self.num_global_inputs = len(global_input)
        self.device = global_input[0].device

        nodesOP, nodesIO = self._convert_nodes(model, global_input)

        while True:
            self._build_graph(nodesOP, nodesIO)
            self.forward(*global_input)
            nodesOP, nodesIO, found_complex = self._split_complex(
                nodesOP, nodesIO)
            if not found_complex: break

        for node in self.nodes:
            for p in list(node.named_parameters()):
                self.register_parameter('{}/{}'.format(node.name, p[0]), p[1])

        logger.debug('NodesOP:')
        for node in nodesOP:
            logger.debug('{}'.format(node._replace(param=None)))
        logger.debug('NodesIO')
        for node in nodesIO:
            logger.debug('{}'.format(node._replace(param=None)))

        if self.verbose:
            logger.info('Model converted to support bounds')
Example #2
    def __init__(self, args, data_train):
        super(LSTM, self).__init__()
        self.args = args
        self.embedding_size = args.embedding_size
        self.max_seq_length = args.max_sent_length
        self.min_word_freq = args.min_word_freq
        self.device = args.device
        self.lr = args.lr

        self.dir = args.dir
        if not os.path.exists(self.dir):
            os.makedirs(self.dir)
        self.vocab = self.vocab_actual = build_vocab(data_train,
                                                     args.min_word_freq)
        self.checkpoint = 0
        if os.path.exists(os.path.join(self.dir, "checkpoint")):
            with open(os.path.join(self.dir, "checkpoint")) as file:
                self.checkpoint = int(file.readline())
            dir_ckpt = os.path.join(self.dir,
                                    "ckpt-{}".format(self.checkpoint))
            path = os.path.join(dir_ckpt, "model")
            self.model = torch.load(path)
            logger.info("Model loaded: {}".format(dir_ckpt))
        else:
            self.embedding = torch.nn.Embedding(len(self.vocab),
                                                self.embedding_size)
            self.model = self.embedding, LSTMFromEmbeddings(
                args, len(self.vocab))
            logger.info("Model initialized")
        self.embedding, self.model_from_embeddings = self.model
        self.embedding = self.embedding.to(self.device)
        self.model_from_embeddings = self.model_from_embeddings.to(self.device)
        self.word_embeddings = self.embedding
Example #3
def gen_ref():
    if args.train:
        train()
    res_transformer, res_lstm = evaluate()
    with open('data/language.pkl', 'wb') as file:
        pickle.dump((res_transformer, res_lstm), file)
    logger.info('Reference results saved')
Example #4
def train(epoch):
    model.train()
    train_batches = get_batches(data_train, args.batch_size)
    for a in avg:
        a.reset()
    eps_inc_per_step = 1.0 / (args.num_epochs_warmup * len(train_batches))
    for i, batch in enumerate(train_batches):
        eps = args.eps * min(
            eps_inc_per_step * ((epoch - 1) * len(train_batches) + i + 1), 1.0)
        acc, acc_robust, loss = res = step(model,
                                           ptb,
                                           batch,
                                           eps=eps,
                                           train=True)
        torch.nn.utils.clip_grad_norm_(model.core.parameters(), 5.0)
        optimizer.step()
        optimizer.zero_grad()
        for k in range(3):
            avg[k].update(res[k], len(batch))
        if (i + 1) % args.log_interval == 0:
            logger.info(
                "Epoch {}, training step {}/{}: acc {:.3f}, acc_robust {:.3f}, loss {:.3f}, eps {:.3f}"
                .format(epoch, i + 1, len(train_batches), avg_acc.avg,
                        avg_acc_robust.avg, avg_loss.avg, eps))
    model.save(epoch)
Example #5
    def _convert(self, model, global_input):
        if self.verbose:
            logger.info('Converting the model...')

        if not isinstance(global_input, tuple):
            global_input = (global_input, )
        self.num_global_inputs = len(global_input)

        nodesOP, nodesIO = self._convert_nodes(model, global_input)
        global_input = tuple([i.to(self.device) for i in global_input])

        while True:
            self._build_graph(nodesOP, nodesIO)
            self.forward(*global_input)
            nodesOP, nodesIO, found_complex = self._split_complex(
                nodesOP, nodesIO)
            if not found_complex: break

        for node in self.nodes:
            for p in list(node.named_parameters()):
                if node.ori_name not in self._parameters:
                    # For parameter or input nodes, use their original name directly
                    self._parameters[node.ori_name] = p[1]

        logger.debug('NodesOP:')
        for node in nodesOP:
            logger.debug('{}'.format(node._replace(param=None)))
        logger.debug('NodesIO')
        for node in nodesIO:
            logger.debug('{}'.format(node._replace(param=None)))

        if self.verbose:
            logger.info('Model converted to support bounds')
Example #6
    def _convert(self, model, global_input):
        if self.verbose:
            logger.info('Converting the model...')

        if not isinstance(global_input, tuple):
            global_input = (global_input, )
        self.num_global_inputs = len(global_input)

        nodesOP, nodesIO = self._convert_nodes(model, global_input)
        global_input = tuple([i.to(self.device) for i in global_input])

        while True:
            self._build_graph(nodesOP, nodesIO)
            self.forward(*global_input)  # running means/vars changed
            nodesOP, nodesIO, found_complex = self._split_complex(
                nodesOP, nodesIO)
            if not found_complex: break

        self._get_node_name_map()

        # reload self.ori_state_dict since forward() above may have changed running means/vars (e.g., in BatchNorm)
        self.load_state_dict(self.ori_state_dict)
        model.load_state_dict(self.ori_state_dict)
        delattr(self, 'ori_state_dict')

        logger.debug('NodesOP:')
        for node in nodesOP:
            logger.debug('{}'.format(node._replace(param=None)))
        logger.debug('NodesIO')
        for node in nodesIO:
            logger.debug('{}'.format(node._replace(param=None)))

        if self.verbose:
            logger.info('Model converted to support bounds')
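Note: the state-dict reload above is needed because tracing the graph calls forward() on real inputs, which mutates BatchNorm-style running statistics as a side effect. A small standalone sketch in plain PyTorch (independent of auto_LiRPA internals) illustrating the side effect and the restore:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(3)
saved = {k: v.clone() for k, v in bn.state_dict().items()}

bn.train()
bn(torch.randn(8, 3))  # forward in train mode updates running_mean / running_var
changed = not torch.equal(saved['running_mean'], bn.state_dict()['running_mean'])

bn.load_state_dict(saved)  # restore the original statistics, as _convert() does
restored = torch.equal(saved['running_mean'], bn.state_dict()['running_mean'])
print(changed, restored)  # True True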
Example #7
def test():
    if not os.path.exists('../examples/language/data'):
        prepare_data()
    if args.gen_ref:
        gen_ref()
    else:
        check()
    logger.info("test_Language done")
Example #8
def load_config(path):
    with open("config/defaults.json") as f:
        config = json.load(f)
    if path is not None:
        logger.info("Loading config file: {}".format(path))
        with open(path) as f:
            update_dict(config, json.load(f))
    return config
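Note: load_config relies on an update_dict helper that is not part of this listing. A plausible recursive merge consistent with how it is called here (an assumption, not the project's actual implementation):

def update_dict(base, overrides):
    # Assumed behavior: recursively merge `overrides` into `base` in place,
    # so nested keys in the user config override the defaults.
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            update_dict(base[key], value)
        else:
            base[key] = value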
Example #9
def train(epoch, batches, type):
    meter = MultiAverageMeter()
    assert(optimizer is not None) 
    train = type == 'train'
    if args.robust:
        eps_scheduler.set_epoch_length(len(batches))    
        if train:
            eps_scheduler.train()
            eps_scheduler.step_epoch()
        else:
            eps_scheduler.eval()
    for i, batch in enumerate(batches):
        if args.robust:
            eps_scheduler.step_batch()
            eps = eps_scheduler.get_eps()
        else:
            eps = 0
        acc, loss, acc_robust, loss_robust = \
            step(model, ptb, batch, eps=eps, train=train)
        meter.update('acc', acc, len(batch))
        meter.update('loss', loss, len(batch))
        meter.update('acc_rob', acc_robust, len(batch))
        meter.update('loss_rob', loss_robust, len(batch))
        if train:
            if (i + 1) % args.gradient_accumulation_steps == 0 or (i + 1) == len(batches):
                scale_gradients(optimizer, i % args.gradient_accumulation_steps + 1, args.grad_clip)
                optimizer.step()
                optimizer.zero_grad()    
            if lr_scheduler is not None:
                lr_scheduler.step()                    
            writer.add_scalar('loss_train_{}'.format(epoch), meter.avg('loss'), i + 1)
            writer.add_scalar('loss_robust_train_{}'.format(epoch), meter.avg('loss_rob'), i + 1)
            writer.add_scalar('acc_train_{}'.format(epoch), meter.avg('acc'), i + 1)
            writer.add_scalar('acc_robust_train_{}'.format(epoch), meter.avg('acc_rob'), i + 1)
        if (i + 1) % args.log_interval == 0 or (i + 1) == len(batches):
            logger.info('Epoch {}, {} step {}/{}: eps {:.5f}, {}'.format(
                epoch, type, i + 1, len(batches), eps, meter))
            if lr_scheduler is not None:
                logger.info('lr {}'.format(lr_scheduler.get_lr()))
    writer.add_scalar('loss/{}'.format(type), meter.avg('loss'), epoch)
    writer.add_scalar('loss_robust/{}'.format(type), meter.avg('loss_rob'), epoch)
    writer.add_scalar('acc/{}'.format(type), meter.avg('acc'), epoch)
    writer.add_scalar('acc_robust/{}'.format(type), meter.avg('acc_rob'), epoch)

    if train:
        if args.loss_fusion:
            state_dict_loss = model_loss.state_dict() 
            state_dict = {}
            for name in state_dict_loss:
                assert(name.startswith('model.'))
                state_dict[name[6:]] = state_dict_loss[name]
            model_ori.load_state_dict(state_dict)
            model_bound = BoundedModule(
                model_ori, (dummy_embeddings, dummy_mask), bound_opts=bound_opts, device=args.device)
            model.model_from_embeddings = model_bound        
        model.save(epoch)

    return meter.avg('acc_rob')
Example #10
    def save(self, epoch):
        self.model.model_from_embeddings = self.model_from_embeddings
        path = os.path.join(self.dir, "ckpt_{}".format(epoch))
        torch.save({
            'state_dict_embeddings': self.model.embeddings.state_dict(),
            'state_dict_model_from_embeddings': self.model.model_from_embeddings.state_dict(),
            'epoch': epoch
        }, path)
        logger.info("Model saved to {}".format(path))
Example #11
def evaluate():
    logger.info('Evaluating the trained Transformer')
    os.system(cmd_transformer_test)
    res_transformer = read_res()
    logger.info('Evaluating the trained LSTM')
    os.system(cmd_lstm_test)
    res_lstm = read_res()
    os.system("rm {}".format(res_path))
    return res_transformer, res_lstm
Example #12
def infer(epoch, batches, type):
    model.eval()
    for a in avg:
        a.reset()
    for i, batch in enumerate(batches):
        acc, acc_robust, loss = res = step(model, ptb, batch)
        for k in range(3):
            avg[k].update(res[k], len(batch))
    logger.info(
        "Epoch {}, {}: acc {:.3f}, acc_robust {:.3f}, loss {:.5f}".format(
            epoch, type, avg_acc.avg, avg_acc_robust.avg, avg_loss.avg))
Example #13
def train():
    if os.path.exists("../examples/language/model_transformer_test"):
        os.system("rm -rf ../examples/language/model_transformer_test")
    if os.path.exists("../examples/language/model_lstm_test"):
        os.system("rm -rf ../examples/language/model_lstm_test")
    logger.info("Training a Transformer")
    os.system(cmd_transformer_train)
    os.system("cp ../examples/language/model_transformer_test/ckpt_2 data/ckpt_transformer")
    logger.info("Training an LSTM")
    os.system(cmd_lstm_train)
    os.system("cp ../examples/language/model_lstm_test/ckpt_2 data/ckpt_lstm")
Example #14
    def save(self, epoch):
        path = os.path.join(self.dir, 'ckpt_{}'.format(epoch))
        torch.save(
            {
                'state_dict_embedding': self.embedding.state_dict(),
                'state_dict_model_from_embeddings': self.model_from_embeddings.state_dict(),
                'epoch': epoch
            }, path)
        logger.info('LSTM saved: {}'.format(path))
Example #15
def load_data_sst():
    # training data
    path = "data/sst/train-nodes.tsv"
    logger.info("Loading data {}".format(path))
    data_train_warmup = []
    with open(path) as file:
        for line in file.readlines()[1:]:
            data_train_warmup.append({
                "sentence": line.split("\t")[0],
                "label": int(line.split("\t")[1])
            })

    # train/dev/test data
    for subset in ["train", "dev", "test"]:
        path = "data/sst/{}.txt".format(subset)
        logger.info("Loading data {}".format(path))
        data = []
        with open(path) as file:
            for line in file.readlines():
                segs = line[:-1].split(" ")
                tokens, word_labels = [], []
                label = int(segs[0][1])
                if label < 2:
                    label = 0
                elif label >= 3:
                    label = 1
                else:
                    continue
                for i in range(len(segs) - 1):
                    if segs[i][0] == "(" and segs[i][1] in ["0", "1", "2", "3", "4"]\
                            and segs[i + 1][0] != "(":
                        tokens.append(segs[i + 1][:segs[i + 1].find(")")])
                        word_labels.append(int(segs[i][1]))
                data.append({
                    "label": label,
                    "sentence": " ".join(tokens),
                    "word_labels": word_labels
                })
        for example in data:
            # "sentence" is already a joined string, so replace the PTB bracket
            # tokens directly instead of iterating (which would yield characters)
            example["sentence"] = example["sentence"].replace(
                "-LRB-", "(").replace("-RRB-", ")")
        if subset == "train":
            data_train = data
        elif subset == "dev":
            data_dev = data
        else:
            data_test = data

    return data_train_warmup, data_train, data_dev, data_test
Example #16
    def save(self, epoch):
        output_dir = os.path.join(self.dir, "ckpt-%d" % epoch)
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        os.mkdir(output_dir)

        path = os.path.join(output_dir, "model")
        torch.save(self.core, path)

        with open(os.path.join(self.dir, "checkpoint"), "w") as file:
            file.write(str(epoch))

        logger.info("LSTM saved: %s" % output_dir)
Example #17
def main():
    if args.train:
        for t in range(model.checkpoint, args.num_epochs):
            if t + 1 <= args.num_epochs_all_nodes:
                train(t + 1, get_batches(data_train_all_nodes,
                                         args.batch_size), 'train')
            else:
                train(t + 1, get_batches(data_train, args.batch_size), 'train')
            train(t + 1, dev_batches, 'dev')
            train(t + 1, test_batches, 'test')
    elif args.oracle:
        oracle(args, model, ptb, data_test, 'test')
    else:
        if args.robust:
            for i in range(args.num_epochs):
                eps_scheduler.step_epoch(verbose=False)
            res = []
            for i in range(1, args.budget + 1):
                logger.info('budget {}'.format(i))
                ptb.budget = i
                acc_rob = train(None, test_batches, 'test')
                res.append(acc_rob)
            logger.info('Verification results:')
            for i in range(len(res)):
                logger.info('budget {} acc_rob {:.3f}'.format(i + 1, res[i]))
            logger.info(res)
        else:
            train(None, test_batches, 'test')
Example #18
    def _build_actual_vocab(self, args, vocab, data_train):
        vocab_actual = {}
        for example in data_train:
            for token in example["sentence"].strip().lower().split():
                if token in vocab:
                    if token not in vocab_actual:
                        vocab_actual[token] = 1
                    else:
                        vocab_actual[token] += 1
        for w in list(vocab_actual.keys()):
            if vocab_actual[w] < self.min_word_freq:
                del vocab_actual[w]
        logger.info("Size of the vocabulary for perturbation: {}".format(
            len(vocab_actual)))
        return vocab_actual
Example #19
def prepare_model(args, logger, config):
    model = args.model

    if config['data'] == 'MNIST':
        input_shape = (1, 28, 28)
    elif config['data'] == 'CIFAR':
        input_shape = (3, 32, 32)
    elif config['data'] == 'tinyimagenet':
        input_shape = (3, 64, 64)
    else:
        raise NotImplementedError(config['data'])

    model_ori = eval(model)(in_ch=input_shape[0],
                            in_dim=input_shape[1],
                            **parse_opts(args.model_params))

    checkpoint = None
    if args.auto_load:
        path_last = os.path.join(args.dir, 'ckpt_last')
        if os.path.exists(path_last):
            args.load = path_last
            logger.info('Use last checkpoint {}'.format(path_last))
        else:
            latest = -1
            for filename in os.listdir(args.dir):
                if filename.startswith('ckpt_'):
                    latest = max(latest, int(filename[5:]))
            if latest != -1:
                args.load = os.path.join(args.dir, 'ckpt_{}'.format(latest))
                try:
                    checkpoint = torch.load(args.load)
                except:
                    logger.warning('Cannot load {}'.format(args.load))
                    args.load = os.path.join(args.dir,
                                             'ckpt_{}'.format(latest - 1))
                    logger.warning('Trying {}'.format(args.load))
    if checkpoint is None and args.load:
        checkpoint = torch.load(args.load)
    if checkpoint is not None:
        epoch, state_dict = checkpoint['epoch'], checkpoint['state_dict']
        best = checkpoint.get('best', (100., 100., -1))
        model_ori.load_state_dict(state_dict, strict=False)
        logger.info(f'Checkpoint loaded: {args.load}, epoch {epoch}')
    else:
        epoch = 0
        best = (100., 100., -1)

    return model_ori, checkpoint, epoch, best
Example #20
    def __init__(self, args, vocab_size):
        super(LSTMFromEmbeddings, self).__init__()

        self.embedding_size = args.embedding_size
        self.hidden_size = args.hidden_size
        self.num_classes = args.num_classes
        self.device = args.device

        self.cell_f = nn.LSTMCell(self.embedding_size, self.hidden_size)
        self.cell_b = nn.LSTMCell(self.embedding_size, self.hidden_size)
        self.linear = nn.Linear(self.hidden_size * 2, self.num_classes)
        if args.dropout is not None:
            self.dropout = nn.Dropout(p=args.dropout)
            logger.info('LSTM dropout: {}'.format(args.dropout))
        else:
            self.dropout = None
Example #21
def get_optimizer(args, params, checkpoint=None):
    if args.opt == 'SGD':
        opt = optim.SGD(params,
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)
    else:
        opt = eval('optim.' + args.opt)(params,
                                        lr=args.lr,
                                        weight_decay=args.weight_decay)
    logger.info(f'Optimizer {opt}')
    if checkpoint:
        if 'optimizer' not in checkpoint:
            logger.error('Cannot find optimizer checkpoint')
        else:
            opt.load_state_dict(checkpoint['optimizer'])
    return opt
Example #22
    def save(self, epoch):
        # the BoundGeneral object should be saved
        self.model.model_from_embeddings = self.model_from_embeddings
        model_to_save = self.model

        output_dir = os.path.join(self.dir, "ckpt-%d" % epoch)
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        os.mkdir(output_dir)

        path = os.path.join(output_dir, "model")
        torch.save(self.model, path)

        with open(os.path.join(self.dir, "checkpoint"), "w") as file:
            file.write(str(epoch))

        logger.info("BERT saved: %s" % output_dir)
Example #23
    def step_epoch(self, verbose=True):
        self.epoch += 1
        self.batch = 0
        if self.epoch < self.schedule_start:
            self.epoch_start_eps = 0
            self.epoch_end_eps = 0
        else:
            eps_epoch = self.epoch - self.schedule_start
            eps_epoch_step = self.max_eps / self.schedule_length
            self.epoch_start_eps = min(eps_epoch * eps_epoch_step, self.max_eps)
            self.epoch_end_eps = min((eps_epoch + 1) * eps_epoch_step, self.max_eps)
        self.eps = self.epoch_start_eps
        if verbose:
            logger.info("Epoch {:3d} eps start {:7.5f} end {:7.5f}".format(
                self.epoch, self.epoch_start_eps, self.epoch_end_eps))
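Note: to make the ramp concrete, the same arithmetic replayed standalone with made-up parameters (max_eps, schedule_start, and schedule_length are illustrative, not values from the source):

max_eps, schedule_start, schedule_length = 0.1, 2, 5
for epoch in range(1, 9):
    if epoch < schedule_start:
        start_eps = end_eps = 0.0
    else:
        eps_epoch = epoch - schedule_start
        eps_epoch_step = max_eps / schedule_length
        start_eps = min(eps_epoch * eps_epoch_step, max_eps)
        end_eps = min((eps_epoch + 1) * eps_epoch_step, max_eps)
    # epoch 4 -> start 0.04, end 0.06; from epoch 7 on both saturate at max_eps
    print(epoch, round(start_eps, 3), round(end_eps, 3))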
Example #24
File: BERT.py  Project: Harry24k/auto_LiRPA
    def __init__(self, args, data_train):
        super(BERT, self).__init__()
        self.args = args
        self.max_seq_length = args.max_sent_length
        self.drop_unk = args.drop_unk
        self.num_labels = args.num_classes
        self.label_list = range(args.num_classes)
        self.device = args.device
        self.lr = args.lr

        self.dir = args.dir
        self.vocab = build_vocab(data_train, args.min_word_freq)
        if not os.path.exists(self.dir):
            os.makedirs(self.dir)
        self.checkpoint = 0
        if os.path.exists(os.path.join(self.dir, "checkpoint")):
            if args.checkpoint is not None:
                self.checkpoint = args.checkpoint
            else:
                with open(os.path.join(self.dir, "checkpoint")) as file:
                    self.checkpoint = int(file.readline())
            dir_ckpt = os.path.join(self.dir,
                                    "ckpt-{}".format(self.checkpoint))
            path = os.path.join(dir_ckpt, "model")
            self.model = torch.load(path)
            self.model.to(self.device)
            logger.info("Model loaded: {}".format(dir_ckpt))
        else:
            config = BertConfig(len(self.vocab))
            config.num_hidden_layers = args.num_layers
            config.embedding_size = args.embedding_size
            config.hidden_size = args.hidden_size
            config.intermediate_size = args.intermediate_size
            config.hidden_act = args.hidden_act
            config.num_attention_heads = args.num_attention_heads
            config.layer_norm = args.layer_norm
            config.hidden_dropout_prob = args.dropout
            self.model = BertForSequenceClassification(config,
                                                       self.num_labels,
                                                       vocab=self.vocab).to(
                                                           self.device)
            logger.info("Model initialized")

        self.model_from_embeddings = self.model.model_from_embeddings
        self.word_embeddings = self.model.embeddings.word_embeddings
        self.model_from_embeddings.device = self.device
Example #25
    def __init__(self,
                 model,
                 global_input,
                 verbose=False,
                 bound_opts=None,
                 device='cpu'):
        super(BoundedModule, self).__init__()
        if isinstance(model, BoundedModule):
            for key in model.__dict__.keys():
                setattr(self, key, getattr(model, key))
            return
        self.verbose = verbose
        self.bound_opts = bound_opts
        self.device = device
        if device == 'cpu':
            # in case the device argument was omitted and defaulted to CPU
            logger.info('Using CPU for the BoundedModule')
        self._convert(model, global_input)
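Note: a minimal usage sketch of this constructor, based only on the signature above; the tiny network, dummy input, and import path are illustrative assumptions:

import torch
import torch.nn as nn
from auto_LiRPA import BoundedModule  # assumed import path

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
dummy_input = torch.zeros(1, 4)

# __init__ calls _convert(), which traces the graph using the dummy input.
model = BoundedModule(net, dummy_input, bound_opts=None, device='cpu')
output = model(dummy_input)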
Example #26
def save(args, epoch, best, model, opt, is_best=False):
    ckpt = {
        'state_dict': model.state_dict(),
        'optimizer': opt.state_dict(),
        'epoch': epoch,
        'best': best
    }
    path_last = os.path.join(args.dir, 'ckpt_last')
    if os.path.exists(path_last):
        os.system('mv {path} {path}.bak'.format(path=path_last))
    torch.save(ckpt, path_last)
    if is_best:
        path_best = os.path.join(args.dir, 'ckpt_best')
        if os.path.exists(path_best):
            os.system('mv {path} {path}.bak'.format(path=path_best))
        torch.save(ckpt, path_best)
    if args.save_all:
        torch.save(ckpt, os.path.join(args.dir, 'ckpt_{}'.format(epoch)))
    logger.info('')
Example #27
    def save(self, epoch):
        self.model = (self.model[0], self.model_from_embeddings)

        model_to_save = self.model.module if hasattr(
            self.model,
            'module') else self.model  # Only save the model it-self

        output_dir = os.path.join(self.dir, "ckpt-%d" % epoch)
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        os.mkdir(output_dir)

        path = os.path.join(output_dir, "model")
        torch.save(self.model, path)

        with open(os.path.join(self.dir, "checkpoint"), "w") as file:
            file.write(str(epoch))

        logger.info("LSTM saved: %s" % output_dir)
Example #28
def build_vocab(data_train, min_word_freq, dump=False, include=[]):
    vocab = {'[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, '[MASK]': 4}
    cnt = {}
    for example in data_train:
        for token in example['sentence'].strip().lower().split():
            if token in cnt:
                cnt[token] += 1
            else:
                cnt[token] = 1
    for w in cnt:
        if cnt[w] >= min_word_freq or w in include:
            vocab[w] = len(vocab)
    logger.info('Vocabulary size: {}'.format(len(vocab)))

    if dump:
        with open('tmp/vocab.txt', 'w') as file:
            for w in vocab.keys():
                file.write('{}\n'.format(w))

    return vocab
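Note: a small usage sketch of build_vocab with a made-up toy corpus (assuming the function and its module-level logger are importable); the [UNK] fallback relies on the special tokens inserted above:

data_train = [{"sentence": "the movie was great"},
              {"sentence": "the plot was thin"}]
vocab = build_vocab(data_train, min_word_freq=1)

# Map an unseen sentence to ids, falling back to [UNK] for unknown words.
tokens = "the movie was boring".lower().split()
ids = [vocab.get(t, vocab['[UNK]']) for t in tokens]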
Example #29
    def __init__(self, args, data_train):
        super(LSTM, self).__init__()
        self.args = args
        self.embedding_size = args.embedding_size
        self.max_seq_length = args.max_sent_length
        self.min_word_freq = args.min_word_freq
        self.device = args.device
        self.lr = args.lr

        self.dir = args.dir
        if not os.path.exists(self.dir):
            os.makedirs(self.dir)
        self.vocab = self.vocab_actual = build_vocab(data_train,
                                                     args.min_word_freq)
        self.checkpoint = 0

        if args.load:
            ckpt = torch.load(args.load,
                              map_location=torch.device(self.device))
            self.embedding = torch.nn.Embedding(len(self.vocab),
                                                self.embedding_size)
            self.model_from_embeddings = LSTMFromEmbeddings(
                args, len(self.vocab))
            self.model = self.embedding, LSTMFromEmbeddings(
                args, len(self.vocab))
            self.embedding.load_state_dict(ckpt['state_dict_embedding'])
            self.model_from_embeddings.load_state_dict(
                ckpt['state_dict_model_from_embeddings'])
            self.checkpoint = ckpt['epoch']
        else:
            self.embedding = torch.nn.Embedding(len(self.vocab),
                                                self.embedding_size)
            self.model_from_embeddings = LSTMFromEmbeddings(
                args, len(self.vocab))
            self.model = self.embedding, LSTMFromEmbeddings(
                args, len(self.vocab))
            logger.info("Model initialized")
        self.embedding = self.embedding.to(self.device)
        self.model_from_embeddings = self.model_from_embeddings.to(self.device)
        self.word_embeddings = self.embedding
Example #30
def train(epoch):
    assert (optimizer is not None)
    if epoch <= args.num_epochs_all_nodes:
        train_batches = get_batches(data_train_warmup, args.batch_size)
    else:
        train_batches = get_batches(data_train, args.batch_size)
    avg_acc.reset()
    avg_loss.reset()
    avg_acc_robust.reset()
    avg_loss_robust.reset()
    if args.robust:
        eps_inc_per_step = 1.0 / (args.num_epochs_warmup * len(train_batches))
    for i, batch in enumerate(train_batches):
        if args.robust:
            eps = min(
                eps_inc_per_step * ((epoch - 1) * len(train_batches) + i + 1),
                1.0)
        else:
            eps = 0.
        acc, loss, acc_robust, loss_robust = \
            step(model, ptb, batch, eps=eps, train=True)
        avg_acc.update(acc, len(batch))
        avg_loss.update(loss, len(batch))
        avg_acc_robust.update(acc_robust, len(batch))
        avg_loss_robust.update(loss_robust, len(batch))
        if (i + 1) % args.gradient_accumulation_steps == 0 or (
                i + 1) == len(train_batches):
            scale_gradients(optimizer,
                            i % args.gradient_accumulation_steps + 1)
            optimizer.step()
            optimizer.zero_grad()
        if (i + 1) % args.log_interval == 0:
            logger.info(
                "Epoch {}, training step {}/{}: acc {:.3f}, loss {:.3f}, acc_robust {:.3f}, loss_robust {:.3f}, eps {:.3f}"
                .format(epoch, i + 1, len(train_batches), avg_acc.avg,
                        avg_loss.avg, avg_acc_robust.avg, avg_loss_robust.avg,
                        eps))
    model.save(epoch)