Beispiel #1
0
class CometLogger(BaseLogger):
    """Logger that forwards every call to a Comet ``Experiment`` object."""

    def __init__(self, experiment_id=None):
        """Create the experiment and optionally record ``experiment_id``.

        The experiment is created with automatic metric logging disabled.
        When ``experiment_id`` is not ``None`` it is logged as the
        ``'experiment_id'`` parameter.
        """
        self.experiment = Experiment(auto_metric_logging=False)
        if experiment_id is None:
            return
        self.experiment.log_parameter('experiment_id', experiment_id)

    def add_scalar(self, name, value, step):
        """Log a single metric; ``step`` is passed on as the ``epoch`` argument."""
        self.experiment.log_metric(name, value, epoch=step)

    def log_parameters(self, params_dict):
        """Log a dictionary of parameters in one call."""
        self.experiment.log_parameters(params_dict)

    def log_metrics(self, metrics_dict, epoch):
        """Log a dictionary of metrics for the given ``epoch``."""
        self.experiment.log_metrics(metrics_dict, epoch=epoch)

    def add_text(self, name, text):
        """Log ``text`` prefixed with ``name`` and a colon."""
        message = '{}: {}'.format(name, text)
        self.experiment.log_text(message)

    def set_context_prefix(self, prefix):
        """Set the experiment's ``context`` attribute to ``prefix``."""
        self.experiment.context = prefix

    def reset_context_prefix(self):
        """Clear the experiment's ``context`` attribute."""
        self.experiment.context = None
Beispiel #2
0
def very_simple_param_count(model):
    """Return the summed ``numel()`` of everything in ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total


if __name__ == "__main__":
    # Script entry point: configure torch and the Comet experiment, parse the
    # command-line arguments, then derive the dataset paths and image list.
    torch.set_num_threads(4) # limit torch to 4 threads
    experiment = Experiment(project_name=PROJECT_NAME, api_key=COMET_ML_API)
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    args = meow_parse()
    print(args)

    # Label the Comet run with the task id and attach the free-form note.
    experiment.set_name(args.task_id)
    experiment.set_cmd_args()
    experiment.log_text(args.note)

    # Train and test data live in fixed sub-directories of the input path.
    DATA_PATH = args.input
    TRAIN_PATH = os.path.join(DATA_PATH, "train_data")
    TEST_PATH = os.path.join(DATA_PATH, "test_data")
    dataset_name = args.datasetname
    # This chain only reports which dataset variant was requested; none of
    # the branches changes any state here.
    if dataset_name=="shanghaitech":
        print("will use shanghaitech dataset with crop ")
    elif dataset_name == "shanghaitech_keepfull":
        print("will use shanghaitech_keepfull")
    else:
        print("cannot detect dataset_name")
        print("current dataset_name is ", dataset_name)

    # Build the list of training images from the training directory.
    train_list = create_image_list(TRAIN_PATH)
Beispiel #3
0
class CorefSolver():
    def __init__(self, args):
        """Create the data utilities and the model described by ``args``.

        When ``args.train`` is truthy the training log file is opened for
        writing, the checkpoint directory is prepared and the path used for
        validation predictions is stored.
        """
        self.args = args
        self.data_utils = data_utils(args)
        self.disable_comet = args.disable_comet

        # Source and target share the same vocabulary size.
        self.model = self.make_model(
            src_vocab=self.data_utils.vocab_size,
            tgt_vocab=self.data_utils.vocab_size,
            N=args.num_layer,
            dropout=args.dropout,
            entity_encoder_type=args.entity_encoder_type,
        )
        print(self.model)

        # Everything below is only needed for training runs.
        if not self.args.train:
            return

        self.outfile = open(self.args.logfile, 'w')
        self.model_dir = make_save_dir(args.model_dir)
        self.w_valid_file = args.w_valid_file

    def make_model(self,
                   src_vocab,
                   tgt_vocab,
                   N=6,
                   dropout=0.1,
                   d_model=512,
                   entity_encoder_type='linear',
                   d_ff=2048,
                   h=8):

        "Helper: Construct a model from hyperparameters."
        c = copy.deepcopy
        attn = MultiHeadedAttention(h, d_model)
        attn_ner = MultiHeadedAttention(1, d_model, dropout)
        ff = PositionwiseFeedForward(d_model, d_ff, dropout)
        position = PositionalEncoding(d_model, dropout)
        embed = Embeddings(d_model, src_vocab)
        word_embed = nn.Sequential(embed, c(position))
        print('pgen', self.args.pointer_gen)

        if entity_encoder_type == 'transformer':
            # entity_encoder = nn.Sequential(embed, Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), 1))
            print('transformer')
            entity_encoder = Seq_Entity_Encoder(
                embed,
                Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), 2))
        elif entity_encoder_type == 'albert':
            albert_tokenizer = AlbertTokenizer.from_pretrained(
                'albert-base-v2')
            albert = AlbertModel.from_pretrained('albert-base-v2')
            entity_encoder = Albert_Encoder(albert, albert_tokenizer, d_model)
        elif entity_encoder_type == 'gru':
            entity_encoder = RNNEncoder(embed,
                                        'GRU',
                                        d_model,
                                        d_model,
                                        num_layers=1,
                                        dropout=0.1,
                                        bidirectional=True)
            print('gru')
        elif entity_encoder_type == 'lstm':
            entity_encoder = RNNEncoder(embed,
                                        'LSTM',
                                        d_model,
                                        d_model,
                                        num_layers=1,
                                        dropout=0.1,
                                        bidirectional=True)
            print('lstm')

        if self.args.ner_at_embedding:
            model = EncoderDecoderOrg(
                Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
                DecoderOrg(
                    DecoderLayerOrg(d_model, c(attn), c(attn), c(ff), dropout),
                    N, d_model, tgt_vocab, self.args.pointer_gen), word_embed,
                word_embed, entity_encoder)
        else:
            if self.args.ner_last:
                decoder = Decoder(
                    DecoderLayer(d_model, c(attn), c(attn), c(ff),
                                 dropout), N, d_model, tgt_vocab,
                    self.args.pointer_gen, self.args.ner_last)
            else:
                decoder = Decoder(
                    DecoderLayer_ner(d_model, c(attn), c(attn), attn_ner,
                                     c(ff), dropout, self.args.fusion), N,
                    d_model, tgt_vocab, self.args.pointer_gen,
                    self.args.ner_last)
            model = EncoderDecoder(
                Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
                decoder, word_embed, word_embed, entity_encoder)

        # This was important from their code.
        # Initialize parameters with Glorot / fan_avg.
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        # levels = 3
        # num_chans = [d_model] * (args.levels)
        # k_size = 5
        # tcn = TCN(embed, d_model, num_channels, k_size, dropout=dropout)

        return model.cuda()

    def train(self):
        """Train the model with periodic logging, validation and checkpoints.

        Every ``args.print_every_steps`` steps (when ``step % n == 1``) the
        mean training loss since the last report and a few sample decodes
        are printed, written to ``self.outfile`` and, unless
        ``self.disable_comet`` is set, sent to Comet. Every
        ``args.valid_every_steps`` steps (when ``step % n == 2``) the
        validation set is evaluated, ROUGE is computed and a checkpoint is
        saved under ``self.model_dir``.
        """
        if not self.disable_comet:
            # logging
            hyper_params = {
                "num_layer": self.args.num_layer,
                "pointer_gen": self.args.pointer_gen,
                "ner_last": self.args.ner_last,
                "entity_encoder_type": self.args.entity_encoder_type,
                "fusion": self.args.fusion,
                "dropout": self.args.dropout,
            }
            COMET_PROJECT_NAME = 'summarization'
            COMET_WORKSPACE = 'timchen0618'

            # NOTE(review): the API key is hard-coded in source. It should be
            # rotated and supplied through configuration or the environment.
            self.exp = Experiment(
                api_key='mVpNOXSjW7eU0tENyeYiWZKsl',
                project_name=COMET_PROJECT_NAME,
                workspace=COMET_WORKSPACE,
                auto_output_logging='simple',
                auto_metric_logging=None,
                display_summary=False,
            )
            self.exp.log_parameters(hyper_params)
            self.exp.add_tags([
                '%s entity_encoder' % self.args.entity_encoder_type,
                self.args.fusion
            ])
            if self.args.ner_last:
                self.exp.add_tag('ner_last')
            if self.args.ner_at_embedding:
                self.exp.add_tag('ner_at_embedding')
            self.exp.set_name(self.args.exp_name)
            self.exp.add_tag('coreference')

        print('ner_last ', self.args.ner_last)
        print('ner_at_embedding', self.args.ner_at_embedding)
        # dataloader & optimizer
        data_yielder = self.data_utils.data_yielder(num_epoch=100)
        optim = torch.optim.Adam(self.model.parameters(),
                                 lr=1e-7,
                                 betas=(0.9, 0.998),
                                 eps=1e-8,
                                 amsgrad=True)
        total_loss = []
        start = time.time()
        print('*' * 50)
        print('Start Training...')
        print('*' * 50)
        start_step = 0

        # if loading from checkpoint
        if self.args.load_model:
            # Deserialize the checkpoint once and reuse it for both fields.
            checkpoint = torch.load(self.args.load_model)
            self.model.load_state_dict(checkpoint['state_dict'])
            print("Loading model from " + self.args.load_model + "...")
            start_step = int(checkpoint['step'])
            print('Resume training from step %d ...' % start_step)

        warmup_steps = 10000
        d_model = 512
        for step in range(start_step, self.args.total_steps):
            self.model.train()
            batch = next(data_yielder)
            optim.zero_grad()

            # Update the learning rate every 400 steps: linear warm-up for
            # ``warmup_steps`` steps followed by inverse-square-root decay.
            if step % 400 == 1:
                lr = (1 / (d_model**0.5)) * min(
                    (1 / (step / 4)**0.5), step * (1 / (warmup_steps**1.5)))
                for param_group in optim.param_groups:
                    param_group['lr'] = lr

            for key in ('src', 'tgt', 'ner', 'src_extended'):
                batch[key] = batch[key].long()

            # forward the model
            ner, ner_mask = self._encode_entities(batch)
            out = self._forward_batch(batch, ner, ner_mask)

            # compute loss & update
            loss = self.model.loss_compute(out, batch['y'].long())
            loss.backward()
            optim.step()

            total_loss.append(loss.detach().cpu().numpy())

            # logging information
            if step % self.args.print_every_steps == 1:
                elapsed = time.time() - start
                mean_train_loss = np.mean(total_loss)
                current_lr = optim.param_groups[0]['lr']
                oov_list = batch['oov_list']

                # First example of the batch, copied to host memory only on
                # reporting steps rather than on every training step.
                pred = out.topk(1, dim=-1)[1].squeeze().detach().cpu().numpy()[0]
                gg = batch['src_extended'].long().detach().cpu().numpy()[0][:100]
                yy = batch['y'].long().detach().cpu().numpy()[0]

                print("Epoch Step: %d Loss: %f Time: %f lr: %6.6f" %
                      (step, mean_train_loss, elapsed, current_lr))
                self.outfile.write("Epoch Step: %d Loss: %f Time: %f\n" %
                                   (step, mean_train_loss, elapsed))
                src_text = self.data_utils.id2sent(gg, False, False, oov_list)
                print('src:\n', src_text)
                tgt_text = self.data_utils.id2sent(yy, False, False, oov_list)
                print('tgt:\n', tgt_text)
                pred_text = self.data_utils.id2sent(pred, False, False,
                                                    oov_list)
                print('pred:\n', pred_text)
                print('oov_list:\n', oov_list)

                # Greedy-decode the first example; the entity mask is only
                # passed when the model consumes entities outside embedding.
                decode_args = [
                    batch['src_extended'].long()[:1], ner[:1],
                    batch['src_mask'][:1], 100, self.data_utils.bos,
                    len(oov_list), self.data_utils.vocab_size, True
                ]
                if ner_mask is not None and not self.args.ner_at_embedding:
                    decode_args.append(ner_mask[:1])
                pp = self.model.greedy_decode(*decode_args)

                pp = pp.detach().cpu().numpy()
                greedy_text = self.data_utils.id2sent(pp[0], False, False,
                                                      oov_list)
                print('pred_greedy:\n', greedy_text)

                print()
                start = time.time()
                if not self.disable_comet:
                    self.exp.log_metric('Train Loss',
                                        mean_train_loss,
                                        step=step)
                    self.exp.log_metric('Learning Rate',
                                        current_lr,
                                        step=step)

                    self.exp.log_text('Src: ' + src_text)
                    self.exp.log_text('Tgt:' + tgt_text)
                    self.exp.log_text('Pred:' + pred_text)
                    self.exp.log_text('Pred Greedy:' + greedy_text)
                    self.exp.log_text('OOV:' + ' '.join(oov_list))

                total_loss = []

            ##########################
            # validation
            ##########################
            if step % self.args.valid_every_steps == 2:
                # Validation losses are kept apart from ``total_loss`` so
                # that the next training-loss report is not contaminated.
                valid_loss, scores = self._validate()
                r1_score = scores['rouge1']
                r2_score = scores['rouge2']

                print('=============================================')
                print('Validation Result -> Loss : %6.6f' % valid_loss)
                print(scores)
                print('=============================================')
                self.outfile.write(
                    '=============================================\n')
                self.outfile.write('Validation Result -> Loss : %6.6f\n' %
                                   valid_loss)
                self.outfile.write(
                    '=============================================\n')
                if not self.disable_comet:
                    self.exp.log_metric('Valid Loss', valid_loss, step=step)
                    self.exp.log_metric('R1 Score', r1_score, step=step)
                    self.exp.log_metric('R2 Score', r2_score, step=step)

                # Saving Checkpoint
                w_step = int(step / 10000)
                print('Saving ' + str(w_step) + 'w_model.pth!\n')
                self.outfile.write('Saving ' + str(w_step) + 'w_model.pth\n')

                model_name = (str(w_step) + 'w_' + '%6.6f' % valid_loss +
                              '%2.3f_' % r1_score + '%2.3f_' % r2_score +
                              'model.pth')
                state = {'step': step, 'state_dict': self.model.state_dict()}
                torch.save(state, os.path.join(self.model_dir, model_name))

    def _encode_entities(self, batch):
        """Encode the batch's entity clusters and return ``(ner, ner_mask)``."""
        encoder_type = self.args.entity_encoder_type
        if encoder_type == 'albert':
            # NOTE(review): these tokenizer outputs are discarded by the
            # padding call below, and ``ner_feat`` is never assigned on this
            # path, so the 'albert' setting fails here. Confirm the intended
            # feature extraction before relying on it.
            d = self.model.entity_encoder.tokenizer.batch_encode_plus(
                batch['ner_text'],
                return_attention_masks=True,
                max_length=10,
                add_special_tokens=False,
                pad_to_max_length=True,
                return_tensors='pt')
            ner_mask = d['attention_mask'].cuda().unsqueeze(1)
            ner = d['input_ids'].cuda()

        if encoder_type == 'gru' or encoder_type == 'lstm':
            ner_feat = self.model.entity_encoder(
                batch['ner'].transpose(0, 1), batch['cluster_len'])[1]
        elif encoder_type == 'transformer':
            mask = gen_mask(batch['cluster_len'])
            ner_feat = self.model.entity_encoder(batch['ner'], mask)
        return self.data_utils.pad_ner_feature(ner_feat.squeeze(),
                                               batch['num_clusters'],
                                               batch['src'].size(0))

    def _forward_batch(self, batch, ner, ner_mask):
        """Run the model forward pass; pass ``ner_mask`` unless NER is at embedding."""
        inputs = (batch['src'], batch['tgt'], ner, batch['src_mask'],
                  batch['tgt_mask'], batch['src_extended'],
                  len(batch['oov_list']))
        if self.args.ner_at_embedding:
            return self.model.forward(*inputs)
        return self.model.forward(*inputs, ner_mask)

    def _validate(self):
        """Evaluate on the validation set and write greedy predictions.

        Predictions are written one per line to ``self.w_valid_file``.

        Returns:
            ``(mean_loss, scores)`` where ``scores`` is the value returned by
            ``cal_rouge_score``. ``mean_loss`` is NaN when the validation
            yielder produced no batches.
        """
        print('*' * 50)
        print('Start Validation...')
        print('*' * 50)
        self.model.eval()
        val_yielder = self.data_utils.data_yielder(1, valid=True)
        valid_losses = []
        # The context manager closes the file even if decoding raises.
        with open(self.w_valid_file, 'w') as fw:
            for batch in val_yielder:
                with torch.no_grad():
                    for key in ('src', 'tgt', 'ner', 'src_extended'):
                        batch[key] = batch[key].long()

                    ner, ner_mask = self._encode_entities(batch)
                    out = self._forward_batch(batch, ner, ner_mask)
                    loss = self.model.loss_compute(out, batch['y'].long())
                    valid_losses.append(loss.item())

                    decode_kwargs = {}
                    if not self.args.ner_at_embedding:
                        decode_kwargs['ner_mask'] = ner_mask
                    pred = self.model.greedy_decode(
                        batch['src_extended'].long(), ner, batch['src_mask'],
                        self.args.max_len, self.data_utils.bos,
                        len(batch['oov_list']), self.data_utils.vocab_size,
                        **decode_kwargs)

                    for hyp in pred:
                        # Drop the first token of each hypothesis.
                        sentence = self.data_utils.id2sent(
                            hyp[1:], True, self.args.beam_size != 1,
                            batch['oov_list'])
                        fw.write(sentence)
                        fw.write("\n")

        scores = cal_rouge_score(self.w_valid_file, self.args.valid_ref_file)
        if valid_losses:
            mean_loss = sum(valid_losses) / len(valid_losses)
        else:
            mean_loss = float('nan')
        return mean_loss, scores

    def test(self):
        """Decode the test data and write one prediction per line.

        Loads the weights from ``args.load_model``, decodes every batch with
        greedy search (``args.beam_size == 1``) or beam search otherwise, and
        writes the sentences to ``args.filename`` inside ``args.pred_dir``.
        """
        # prepare model
        state_dict = torch.load(self.args.load_model)['state_dict']
        max_len = self.args.max_len
        self.model.load_state_dict(state_dict)

        pred_dir = make_save_dir(self.args.pred_dir)
        filename = self.args.filename

        # start decoding
        data_yielder = self.data_utils.data_yielder(num_epoch=1)
        start = time.time()

        # The context manager guarantees the predictions are flushed and the
        # file is closed, including when decoding raises.
        with open(os.path.join(pred_dir, filename), 'w') as f:
            self.model.eval()

            for step, batch in enumerate(data_yielder, start=1):
                if step % 100 == 0:
                    print('%d batch processed. Time elapsed: %f min.' %
                          (step, (time.time() - start) / 60.0))
                    start = time.time()

                ### ner ###
                if self.args.entity_encoder_type == 'albert':
                    d = self.model.entity_encoder.tokenizer.batch_encode_plus(
                        batch['ner_text'],
                        return_attention_masks=True,
                        max_length=10,
                        add_special_tokens=False,
                        pad_to_max_length=True,
                        return_tensors='pt')
                    ner_mask = d['attention_mask'].cuda().unsqueeze(1)
                    ner = d['input_ids'].cuda()
                else:
                    ner_mask = None
                    ner = batch['ner'].long()

                with torch.no_grad():
                    if self.args.beam_size == 1:
                        # The entity mask is only passed when entities are
                        # not consumed at the embedding layer.
                        decode_kwargs = {}
                        if not self.args.ner_at_embedding:
                            decode_kwargs['ner_mask'] = ner_mask
                        out = self.model.greedy_decode(
                            batch['src_extended'].long(),
                            self.model.entity_encoder(ner),
                            batch['src_mask'], max_len, self.data_utils.bos,
                            len(batch['oov_list']),
                            self.data_utils.vocab_size, **decode_kwargs)
                    else:
                        ret = self.beam_decode(batch, max_len,
                                               len(batch['oov_list']))
                        out = ret['predictions']

                for hyp in out:
                    # Drop the first element of each hypothesis.
                    sentence = self.data_utils.id2sent(
                        hyp[1:], True, self.args.beam_size != 1,
                        batch['oov_list'])
                    f.write(sentence)
                    f.write("\n")

    def beam_decode(self, batch, max_len, oov_nums):
        """Beam-search decode ``batch`` and return the result of ``_from_beam``.

        The batch must contain exactly one example (enforced by an assert
        after encoding). Decoding runs for at most ``self.args.max_len``
        steps or until every beam reports it is done.
        """
        opts = self.args
        utils = self.data_utils

        src = batch['src'].long()
        src_mask = batch['src_mask']
        src_extended = batch['src_extended'].long()

        bos_token = utils.bos
        beam_size = opts.beam_size
        vocab_size = utils.vocab_size
        batch_size = src.size(0)

        ### ner ###
        if opts.entity_encoder_type == 'albert':
            encoded = self.model.entity_encoder.tokenizer.batch_encode_plus(
                batch['ner_text'],
                return_attention_masks=True,
                max_length=10,
                add_special_tokens=False,
                pad_to_max_length=True,
                return_tensors='pt')
            ner_mask = encoded['attention_mask'].cuda().unsqueeze(1)
            ner = encoded['input_ids'].cuda()
        else:
            ner_mask = None
            ner = batch['ner'].long()
        ner = self.model.entity_encoder(ner)

        # Entities are fed to the encoder only in the ner_at_embedding setup.
        if opts.ner_at_embedding:
            memory = self.model.encode(src, src_mask, ner)
        else:
            memory = self.model.encode(src, src_mask)

        assert batch_size == 1

        beam = [
            Beam(beam_size,
                 utils.pad,
                 bos_token,
                 utils.eos,
                 min_length=opts.min_length) for _ in range(batch_size)
        ]

        # Tile every per-example tensor along the first axis, once per beam.
        memory = memory.repeat(beam_size, 1, 1)
        ner = ner.repeat(beam_size, 1, 1)
        src_mask = src_mask.repeat(beam_size, 1, 1)
        src_extended = src_extended.repeat(beam_size, 1)

        # The mask keyword is only supplied outside the ner_at_embedding setup.
        decode_kwargs = {} if opts.ner_at_embedding else {'ner_mask': ner_mask}

        for _ in range(opts.max_len):
            if all(b.done() for b in beam):
                break

            # Gather the pending word of every beam as a column of inputs.
            current = torch.stack([b.get_current_state() for b in beam])
            inp = current.t().contiguous().view(-1, 1)
            # Ids at or beyond the vocabulary size are zeroed out before
            # being fed back to the decoder.
            in_vocab = inp < vocab_size
            decoder_input = inp * in_vocab.long()

            final_dist = self.model.decode(memory, ner, src_mask,
                                           decoder_input, None, src_extended,
                                           oov_nums, **decode_kwargs)

            # Reshape the word scores to (beam, batch, -1).
            out = final_dist.view(beam_size, batch_size, -1)
            out[:, :, 2] = 0  # no unk (index 2 is presumably the UNK id)

            # Advance each beam with its slice of the scores.
            for j, b in enumerate(beam):
                b.advance(out[:, j])

        # Extract sentences from beam.
        return self._from_beam(beam)

    def _from_beam(self, beam):
        ret = {"predictions": [], "scores": []}
        for b in beam:

            n_best = self.args.n_best
            scores, ks = b.sort_finished(minimum=n_best)
            hyps = []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp = b.get_hyp(times, k)
                hyps.append(hyp)

            ret["predictions"].append(hyps)
            ret["scores"].append(scores)

        return ret
            'weight_decay':
            1e-6
        }, {
            'params': [
                p for n, p in model_caption.named_parameters()
                if n in word_embedding_param_names
            ],
            'weight_decay':
            0.0
        }],
        lr=1e-3)

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch_bak = checkpoint['epoch'] + 1

    experiment.log_text("MOBILENET captions")

    val_loss_old = 0.0
    val_loss = 0.0
    step_count = 0

    with experiment.train():
        for epoch in range(0, params['epochs']):
            if step_count <= 3:
                temporal_loss = 0.0
                caption_loss = 0.0
                running_loss = 0.0
                total_batches = 0.0

                for i, data in enumerate(tqdm(train_dl, file=sys.stdout)):
                    # zero the parameter gradients
Beispiel #5
0
class experiment_logger:
    '''
    Interface for logging experiments on neptune, comet, or both.

    Args: (log_backend, project_name)
    Other backends may also be added in the future.
    Currently defined methods:
        add_params: key/value parameters
        add_tags: list of tags
        log_text: strings
        log_metrics: numerical values
        log_figure: pyplot figures
        log_image: images (comet only)
        log_hist3d: 3d histograms (comet only)
        add_table / log_table: tabular data (comet only)
        stop: end logging and close connection
    '''
    # Suffixes appended to the metric name when a 2-tuple is logged.
    _TUPLE_SUFFIXES = (" (r)", " (p-val)")

    def __init__(self, log_backend, project_name):
        '''
        Open the requested logging backend(s).

        Parameters
        ----------
        log_backend : STR
            One of 'comet', 'neptune', 'all'
        project_name : STR
            one of the available projects ('yeast', 'jersey', 'wheat', 'debug', etc)

        Raises
        ------
        ValueError
            If log_backend selects neither neptune nor comet.

        Returns
        -------
        None.

        '''
        self.proj_name = project_name
        self.backend = log_backend
        # Bool indicating whether neptune logging is enabled
        self.neptune = log_backend in ('neptune', 'all')
        # Bool indicating whether comet logging is enabled
        self.comet = log_backend in ('comet', 'all')
        # NOTE(review): the credentials below are hard-coded in source; they
        # should be rotated and loaded from the environment or a config file.
        if self.neptune:
            neptune.init("dna-i/"+project_name,
                         api_token='eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm5lcHR1bmUuYWkiLCJhcGlfa2V5IjoiMWYzMzhjMjItYjczNC00NzZhLWFlZTYtOTI2NzE5MzUwZmNkIn0=')
            print("logging experiments on neptune project "+project_name)
            neptune.create_experiment()
        if self.comet:
            self.comet_experiment = Experiment(api_key="V0OXnWOi4KVNS4OkwLjdnxSgK",
                            project_name=project_name, workspace="dna-i")
            print("logging experiments on comet project "+project_name)
        if not (self.neptune or self.comet):
            raise ValueError('Logging Backend NOT Available')

    def add_params(self, params, step=None):
        '''
        Adds parameters to experiment log

        Parameters
        ----------
        params : Dict
            Key-Value pairs
        step : (OPT) INT
            On neptune it is stored as a 'step' property; on comet it is
            forwarded to log_parameters.

        Returns
        -------
        None.

        '''
        if self.neptune:
            for key, value in params.items():
                neptune.set_property(key, value)
            if step is not None:
                neptune.set_property('step', step)
        if self.comet:
            self.comet_experiment.log_parameters(params, step=step)

    def add_tags(self, tags):
        '''
        Adds tags to experiment log

        Parameters
        ----------
        tags : list
            list of tags (strings)
            e.g.: ['tag1', 'tag2']

        Returns
        -------
        None.

        '''
        if self.neptune:
            neptune.append_tag(tags)
        if self.comet:
            self.comet_experiment.add_tags(tags)

    def log_metrics(self, name, value, epoch=None):
        '''
        Logging pointwise metrics

        Parameters
        ----------
        name : STR
            Metric key
        value : Float/Integer/(Boolean/String)
            Comet also allows Boolean/string
            Tuples are allowed: the elements are logged as separate metrics
            named name + ' (r)' and name + ' (p-val)'
        epoch: (OPT)  INT
            Epoch - or anything used as x axis when plotting metrics

        Returns
        -------
        None.

        '''
        is_tuple = isinstance(value, tuple)

        if self.neptune:
            try:
                if is_tuple:
                    print("Logging tuple as r and p-value")
                    for val, suffix in zip(value, self._TUPLE_SUFFIXES):
                        if epoch is not None:
                            neptune.log_metric(name + suffix, epoch, y=val)
                        else:
                            neptune.log_metric(name + suffix, val)
                elif epoch is not None:
                    neptune.log_metric(name, epoch, y=value)
                else:
                    neptune.log_metric(name, value)
            except Exception:
                # Best-effort fallback: keep the value in the log as text
                # instead of aborting the run. Exception (not a bare except)
                # so Ctrl-C / SystemExit still propagate.
                print("Metric type {} not supported by neptune.".format(type(value)))
                print("logging as text")
                self.log_text("{}".format(value), key=name)

        if self.comet:
            try:
                if is_tuple:
                    print("Logging tuple as r and p-value")
                    for val, suffix in zip(value, self._TUPLE_SUFFIXES):
                        if epoch is not None:
                            # The tuple path indexes by step, the scalar path
                            # below by epoch (kept as in the original code).
                            self.comet_experiment.log_metric(name + suffix, val, step=int(epoch))
                        else:
                            self.comet_experiment.log_metric(name + suffix, val)
                elif epoch is not None:
                    self.comet_experiment.log_metric(name, value, epoch=epoch)
                else:
                    self.comet_experiment.log_metric(name, value)
            except Exception:
                print("Metric type {} not supported by comet.".format(type(value)))
                if is_tuple:
                    print("Logging tuple as x-y pairs")
                    for idx, val in enumerate(value):
                        self.comet_experiment.log_metric(name, val, epoch=idx)
                else:
                    print("Logging as other.")
                    self.comet_experiment.log_other(name, value)

    def log_text(self, string, key=None, epoch=None):
        '''
          Logs text strings

          Parameters
          ----------
          string : STR
              text to  log
          key: STR
              log_name needed for Neptune strings; when missing the dummy
              name 'text' is used. On comet the key is prepended to the text.
          epoch: INT
              epoch or any other index

          Returns
          -------
          None.

        '''
        if self.neptune:
            if isinstance(string, str):
                neptune_key = key
                if neptune_key is None:
                    print('Neptune log_name needed for logging text')
                    print('Using a dummy name: text')
                    # Use the dummy name for the single call below instead of
                    # logging once here and a second time with key=None.
                    neptune_key = 'text'
                if epoch is None:
                    neptune.log_text(neptune_key, string)
                else:
                    neptune.log_text(neptune_key, epoch, y=string)
            else:
                print("Wrong type: logging text must be a string")
        if self.comet:
            if isinstance(string, str):
                if key is not None:
                    print("Commet text logging does not  support keys, prepending it to text")
                    string = key+ ', '+string
                if epoch is None:
                    self.comet_experiment.log_text(string)
                else:
                    self.comet_experiment.log_text(string, step=epoch)
            else:
                print("Wrong type: logging text must be a string")

    def log_figure(self, figure=None, figure_name=None, step=None):
        '''
        Logs pyplot figure

        Parameters
        ----------
        figure : pyplot figure, optional in comet mandatory in neptune.
            The default is None, uses global pyplot figure.
        figure_name : STR, optional in comet mandatory in neptune.
             The default is None.
        step : INT, optional
            An index. The default is None.

        Returns
        -------
        None.

        '''
        if self.neptune:
            if figure is not None:
                if figure_name is None:
                    print("Figure name must be given to neptune logger")
                    print("Using dummy name: figure")
                    figure_name = 'figure'
                if step is None:
                    neptune.log_image(figure_name, figure)
                else:
                    neptune.log_image(figure_name, step, y=figure)
            else:
                print("A figure must be passed to neptune logger")
        if self.comet:
            self.comet_experiment.log_figure(figure_name=figure_name, figure=figure, step=step)

    def stop(self):
        '''
        Ends logging on every enabled backend.
        '''
        if self.neptune:
            neptune.stop()
        if self.comet:
            self.comet_experiment.end()

    def add_table(self, filename, tabular_data=None, headers=False):
        '''
        Logs tabular data (comet only)

        Parameters
        ----------
        filename : STR
            Table file name, forwarded unchanged to comet
        tabular_data : optional
            Table contents, forwarded unchanged to comet
        headers : optional
            Forwarded unchanged to comet. The default is False.

        Returns
        -------
        None.

        '''
        if self.neptune:
            print("not implemented")
        if self.comet:
            self.comet_experiment.log_table(filename, tabular_data, headers)

    def log_image(self, image=None, figure_name=None, step=None):
        '''
        Logs an image (comet only)

        Parameters
        ----------
        image : image data, forwarded to comet's log_image.
            The default is None.
        figure_name : STR, optional
             Name given to the logged image. The default is None.
        step : INT, optional
            An index. The default is None.

        Returns
        -------
        None.

        '''
        if self.neptune:
            print("not implemented")
        if self.comet:
            # Delegate to the comet experiment; calling self.log_image here
            # would recurse into this wrapper with unsupported keywords.
            self.comet_experiment.log_image(
                image, name=figure_name, overwrite=False, image_format="png",
                image_scale=1.0, image_shape=None, image_colormap=None,
                image_minmax=None, image_channels="last", copy_to_tmp=True,
                step=step)

    def log_hist3d(self, values=None, figure_name=None, step=None):
        '''
        Logs a 3d histogram (comet only)

        Parameters
        ----------
        values : values to build the histogram from, forwarded to comet.
            The default is None.
        figure_name : STR, optional
             Name given to the histogram. The default is None.
        step : INT, optional
            An index. The default is None.

        Returns
        -------
        None.

        '''
        if self.neptune:
            print("not implemented")
        if self.comet:
            self.comet_experiment.log_histogram_3d(values, name=figure_name, step=step)

    def log_table(self, name=None, data=None, headers=False):
        '''
        Logs tabular data as '<name>.csv' (comet only)

        Parameters
        ----------
        name : str
            Table name, '.csv' is appended. Must be a string when comet is
            enabled (the None default fails on concatenation).
        data : array, list

        headers : TYPE, optional
            whether to use headers

        Returns
        -------
        None.

        '''
        if self.neptune:
            print("not implemented")
        if self.comet:
            self.comet_experiment.log_table(name+'.csv', tabular_data=data, headers=headers)
Beispiel #6
0
class ModelTrainer:
    """Train, validate and test a classifier while logging to comet.

    The trainer owns the optimizer, an optional learning-rate scheduler and
    the comet ``Experiment`` that all metrics are reported to.
    """

    def __init__(self, model, dataloader, args):
        """Set up device, optimizer, scheduler and the comet experiment.

        Args:
            model: network to train; it is moved to CUDA when available,
                otherwise to the CPU.
            dataloader: mapping indexed with 'train', 'val' and 'test' by the
                methods of this class, or None.
            args: namespace of settings; the attributes read here are
                metric, frq_log, optimizer, lr, momentum, weight_decay,
                beta1, scheduler, scheduler_factor, comet_key, comet_project,
                comet_workspace and name.

        Raises:
            Exception: if ``args.optimizer`` is neither 'sgd' nor 'adam'.
        """
        self.model = model
        self.args = args
        self.data = dataloader
        self.metric = args.metric

        # Number of batches between two "log_*" reports. Clamped to at least
        # 1 so that a loader with fewer batches than args.frq_log does not
        # make the modulo in train_one_epoch divide by zero.
        self.frq_log = 1
        if dataloader is not None:
            self.frq_log = max(1, len(dataloader['train']) // args.frq_log)

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        model.to(self.device)

        if args.optimizer == 'sgd':
            self.optimizer = optim.SGD(model.parameters(),
                                       lr=args.lr,
                                       momentum=args.momentum,
                                       weight_decay=args.weight_decay)
        elif args.optimizer == 'adam':
            self.optimizer = optim.Adam(model.parameters(),
                                        lr=args.lr,
                                        betas=(args.beta1, 0.999),
                                        weight_decay=args.weight_decay)
        else:
            raise Exception('--optimizer should be one of {sgd, adam}')

        # Any other scheduler value leaves self.scheduler undefined; train()
        # only touches it for 'set' and 'auto'.
        if args.scheduler == 'set':
            # Multiplies the base lr by 10**(epoch / scheduler_factor).
            self.scheduler = optim.lr_scheduler.LambdaLR(
                self.optimizer,
                lambda epoch: 10**(epoch / args.scheduler_factor))
        elif args.scheduler == 'auto':
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode='min',
                factor=args.scheduler_factor,
                patience=5,
                verbose=True,
                threshold=0.0001,
                threshold_mode='rel',
                cooldown=0,
                min_lr=0,
                eps=1e-08)

        self.experiment = Experiment(api_key=args.comet_key,
                                     project_name=args.comet_project,
                                     workspace=args.comet_workspace,
                                     auto_weight_logging=True,
                                     auto_metric_logging=False,
                                     auto_param_logging=False)

        self.experiment.set_name(args.name)
        self.experiment.log_parameters(vars(args))
        self.experiment.set_model_graph(str(self.model))

    def train_one_epoch(self, epoch):
        """Run one pass over the training loader.

        Args:
            epoch: zero-based epoch index, used to derive the comet step.

        Returns:
            dict with the per-sample average 'loss' and the 'acc' in percent
            over the whole training dataset.
        """
        self.model.train()
        train_loader = self.data['train']
        train_loss = 0
        correct = 0

        # Global step of the first batch of this epoch.
        comet_offset = epoch * len(train_loader)

        for batch_idx, (data, target) in tqdm(enumerate(train_loader),
                                              leave=True,
                                              total=len(train_loader)):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            # Summed (not averaged) so that the epoch total can be divided by
            # the dataset size below.
            loss = F.cross_entropy(output, target, reduction='sum')
            loss.backward()
            self.optimizer.step()

            pred = output.argmax(dim=1, keepdim=True)
            acc = pred.eq(target.view_as(pred)).sum().item()
            train_loss += loss.item()
            correct += acc

            # Per-batch averages for logging.
            loss = loss.item() / len(data)
            acc = 100. * acc / len(data)

            comet_step = comet_offset + batch_idx
            self.experiment.log_metric('batch_loss', loss, comet_step, epoch)
            self.experiment.log_metric('batch_acc', acc, comet_step, epoch)

            if (batch_idx + 1) % self.frq_log == 0:
                self.experiment.log_metric('log_loss', loss, comet_step, epoch)
                self.experiment.log_metric('log_acc', acc, comet_step, epoch)
                print('Epoch: {} [{}/{}]\tLoss: {:.6f}\tAcc: {:.2f}%'.format(
                    epoch + 1, (batch_idx + 1) * len(data),
                    len(train_loader.dataset), loss, acc))

        train_loss /= len(train_loader.dataset)
        acc = 100. * correct / len(train_loader.dataset)

        comet_step = comet_offset + len(train_loader) - 1
        self.experiment.log_metric('loss', train_loss, comet_step, epoch)
        self.experiment.log_metric('acc', acc, comet_step, epoch)

        print(
            'Epoch: {} [Done]\tLoss: {:.4f}\tAccuracy: {}/{} ({:.2f}%)'.format(
                epoch + 1, train_loss, correct, len(train_loader.dataset),
                acc))

        return {'loss': train_loss, 'acc': acc}

    def train(self):
        """Train for ``args.nepoch`` epochs, validating after each one.

        Weights are saved whenever the validation value of ``self.metric``
        exceeds the best one seen so far. Whether training finishes or is
        interrupted, the weights folder is uploaded to comet and, for the
        'set' scheduler, a learning-rate/loss curve is logged and shown.
        """
        self.log_cmd()
        best = -1
        history = {'lr': [], 'train_loss': []}

        try:
            print(">> Training %s" % self.model.name)
            for epoch in range(self.args.nepoch):
                with self.experiment.train():
                    train_res = self.train_one_epoch(epoch)

                with self.experiment.validate():
                    print("\nvalidation...")
                    # Same step as the last training batch of this epoch.
                    comet_offset = (epoch + 1) * len(self.data['train']) - 1
                    res = self.val(self.data['val'], comet_offset, epoch)

                if res[self.metric] > best:
                    best = res[self.metric]
                    self.save_weights(epoch)

                if self.args.scheduler == 'set':
                    lr = self.optimizer.param_groups[0]['lr']
                    history['lr'].append(lr)
                    history['train_loss'].append(train_res['loss'])

                    self.scheduler.step(epoch + 1)
                    lr = self.optimizer.param_groups[0]['lr']
                    print('learning rate changed to: %.10f' % lr)

                elif self.args.scheduler == 'auto':
                    self.scheduler.step(train_res['loss'])
        finally:
            print(">> Training model %s. [Stopped]" % self.model.name)
            self.experiment.log_asset_folder(os.path.join(
                self.args.outf, self.args.name, 'weights'),
                                             step=None,
                                             log_file_name=False,
                                             recursive=False)
            if self.args.scheduler == 'set':
                plt.semilogx(history['lr'], history['train_loss'])
                plt.grid(True)
                self.experiment.log_figure(figure=plt)
                plt.show()

    def val(self, val_loader, comet_offset=-1, epoch=-1):
        """Evaluate the model on ``val_loader`` and log the results.

        Args:
            val_loader: loader yielding (data, target) batches and exposing
                ``.dataset``.
            comet_offset: step under which the metrics are logged.
            epoch: epoch under which the metrics are logged; also used in the
                confusion-matrix title and file name.

        Returns:
            dict with the per-sample average 'loss' and 'acc' in percent.
        """
        self.model.eval()
        test_loss = 0
        correct = 0

        labels = list(range(self.args.nclass))
        cm = np.zeros((len(labels), len(labels)))

        with torch.no_grad():
            for data, target in tqdm(val_loader,
                                     leave=True,
                                     total=len(val_loader)):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                test_loss += F.cross_entropy(output, target,
                                             reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

                # Accumulate the confusion matrix batch by batch.
                pred = pred.view_as(target).data.cpu().numpy()
                target = target.data.cpu().numpy()
                cm += confusion_matrix(target, pred, labels=labels)

        test_loss /= len(val_loader.dataset)
        accuracy = 100. * correct / len(val_loader.dataset)

        print('Evaluation: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.
              format(test_loss, correct, len(val_loader.dataset), accuracy))

        res = {'loss': test_loss, 'acc': accuracy}

        self.experiment.log_metrics(res, step=comet_offset, epoch=epoch)
        self.experiment.log_confusion_matrix(
            matrix=cm,
            labels=[ClassDict.getName(x) for x in labels],
            title='confusion matrix after epoch %03d' % epoch,
            file_name="confusion_matrix_%03d.json" % epoch)

        return res

    def test(self):
        """Load saved weights and evaluate on the 'test' loader.

        Returns:
            The dict produced by ``val`` for the test loader.
        """
        self.load_weights()
        with self.experiment.test():
            print('\ntesting....')
            res = self.val(self.data['test'])
        return res

    def log_cmd(self):
        """Log a ``!python main.py ...`` command line rebuilt from the args.

        Arguments that are None, an empty string or False are left out; True
        flags are written without a value.
        """
        d = vars(self.args)
        cmd = '!python main.py \\\n'
        tab = '    '

        for k, v in d.items():
            if v is None or v == '' or (isinstance(v, bool) and v is False):
                continue

            if isinstance(v, bool):
                arg = '--{} \\\n'.format(k)
            else:
                arg = '--{} {} \\\n'.format(k, v)

            cmd = cmd + tab + arg

        self.experiment.log_text(cmd)

    def save_weights(self, epoch: int):
        """Save epoch and model state dict to ``<outf>/<name>/weights/model.pth``."""
        weight_dir = os.path.join(self.args.outf, self.args.name, 'weights')
        if not os.path.exists(weight_dir):
            os.makedirs(weight_dir)

        torch.save({
            'epoch': epoch,
            'state_dict': self.model.state_dict()
        }, os.path.join(weight_dir, 'model.pth'))

    def load_weights(self):
        """Load model weights from ``args.weights_path``.

        Falls back to the file written by ``save_weights`` when
        ``args.weights_path`` is None.
        """
        path_g = self.args.weights_path

        if path_g is None:
            weight_dir = os.path.join(self.args.outf, self.args.name,
                                      'weights')
            path_g = os.path.join(weight_dir, 'model.pth')

        print('>> Loading weights...')
        weights_g = torch.load(path_g, map_location=self.device)['state_dict']
        self.model.load_state_dict(weights_g)
        print('   Done.')

    def predict(self, x):
        """Return softmax class probabilities for a single numpy input.

        Args:
            x: numpy array; it is divided by 2**15 before inference
                (presumably 16-bit PCM scaling -- confirm with the caller).

        Returns:
            numpy array of probabilities with a leading batch axis of size 1.
        """
        x = x / 2**15
        self.model.eval()
        with torch.no_grad():
            x = torch.from_numpy(x).float()
            # NOTE(review): self.transform is not set anywhere in this class;
            # it must be attached by the caller before predict() is used.
            x = self.transform(x)
            # The model lives on self.device, so the input has to as well.
            x = x.unsqueeze(0).to(self.device)
            x = self.model(x)
            x = F.softmax(x, dim=1)
            # .numpy() only works on CPU tensors.
            x = x.cpu().numpy()
        return x
Beispiel #7
0
def get_params():
    """Parse command-line options and prepare the run.

    Besides parsing, this seeds torch, optionally overrides the comet
    credentials from ``settings.json``, optionally (re)creates the
    adversarial sample file, picks the torch device and, when comet logging
    is requested, creates the experiment and stores it on the namespace.

    Returns:
        argparse.Namespace with all options plus ``classes``,
        ``sample_file``, ``device`` and, when ``--comet`` or
        ``--offline_comet`` is given, ``experiment``.
    """

    def _str2bool(text):
        # argparse's type=bool calls bool() on the raw string, which is True
        # for every non-empty value, "False" included. Explicit negative
        # spellings map to False; any other value keeps the old truthy result.
        if isinstance(text, bool):
            return text
        return text.strip().lower() not in ('', 'false', 'f', 'no', 'n', '0',
                                            'off')

    parser = argparse.ArgumentParser(description='Perm')
    # Hparams
    padd = parser.add_argument
    # NOTE: '--batch-size' and '--batch_size' (further down) both store into
    # args.batch_size.
    padd('--batch-size',
         type=int,
         default=64,
         metavar='N',
         help='input batch size for training (default: 64)')
    padd('--latent_dim',
         type=int,
         default=20,
         metavar='N',
         help='Latent dim for VAE')
    padd('--lr',
         type=float,
         default=0.01,
         metavar='LR',
         help='learning rate (default: 0.01)')
    padd('--momentum',
         type=float,
         default=0.5,
         metavar='M',
         help='SGD momentum (default: 0.5)')
    padd('--latent_size',
         type=int,
         default=50,
         metavar='N',
         help='Size of latent distribution (default: 50)')
    padd('--estimator',
         default='reinforce',
         const='reinforce',
         nargs='?',
         choices=['reinforce', 'lax'],
         help='Grad estimator for noise (default: %(default)s)')
    padd('--reward',
         default='soft',
         const='soft',
         nargs='?',
         choices=['soft', 'hard'],
         help='Reward for grad estimator (default: %(default)s)')

    # Training
    padd('--epochs',
         type=int,
         default=10,
         metavar='N',
         help='number of epochs to train (default: 10)')
    padd('--PGD_steps',
         type=int,
         default=40,
         metavar='N',
         help='max gradient steps (default: %(default)s)')
    padd('--max_iter',
         type=int,
         default=20,
         metavar='N',
         help='max gradient steps (default: %(default)s)')
    padd('--max_batches',
         type=int,
         default=None,
         metavar='N',
         help=
         'max number of batches per epoch, used for debugging (default: None)')
    padd('--epsilon',
         type=float,
         default=0.5,
         metavar='M',
         help='Epsilon for Delta (default: %(default)s)')
    padd('--LAMBDA',
         type=float,
         default=100,
         metavar='M',
         help='Lambda for L2 lagrange penalty (default: %(default)s)')
    padd('--nn_temp',
         type=float,
         default=1.0,
         metavar='M',
         help='Starting diff. nearest neighbour temp (default: 1.0)')
    padd('--temp_decay_rate',
         type=float,
         default=0.9,
         metavar='M',
         help='Nearest neighbour temp decay rate (default: 0.9)')
    padd('--temp_decay_schedule',
         type=float,
         default=100,
         metavar='M',
         help='How many batches before decay (default: 100)')
    padd('--bb_steps',
         type=int,
         default=2000,
         metavar='N',
         help='Max black box steps per sample (default: %(default)s)')
    padd('--attack_epochs',
         type=int,
         default=10,
         metavar='N',
         help='Max numbe of epochs to train G')
    padd('--seed',
         type=int,
         default=1,
         metavar='S',
         help='random seed (default: 1)')
    padd('--batch_size', type=int, default=256, metavar='S', help='Batch size')
    padd('--embedding_dim', type=int, default=300, help='embedding_dim')
    padd('--embedding_type',
         type=str,
         default="non-static",
         help='embedding_type')
    padd('--test_batch_size',
         type=int,
         default=128,
         metavar='N',
         help='Test Batch size. 256 requires 12GB GPU memory')
    padd('--test',
         default=False,
         action='store_true',
         help='just test model and print accuracy')
    padd('--deterministic_G',
         default=False,
         action='store_true',
         help='Auto-encoder, no VAE')
    padd('--resample_test',
         default=False,
         action='store_true',
         help='Load model and test resampling capability')
    padd('--resample_iterations',
         type=int,
         default=100,
         metavar='N',
         help='How many times to resample (default: 100)')
    # NOTE(review): default=True with store_true means this flag can never
    # be switched off from the command line.
    padd('--clip_grad',
         default=True,
         action='store_true',
         help='Clip grad norm')
    padd('--train_vae', default=False, action='store_true', help='Train VAE')
    padd('--train_ae', default=False, action='store_true', help='Train AE')
    padd('--use_flow',
         default=False,
         action='store_true',
         help='Add A NF to Generator')
    padd('--carlini_loss',
         default=False,
         action='store_true',
         help='Use CW loss function')
    padd('--vanilla_G',
         default=False,
         action='store_true',
         help='Vanilla G White Box')
    padd('--prepared_data',
         default='dataloader/prepared_data.pickle',
         help='Test on a single data')

    # Imported Model Params
    padd('--emsize', type=int, default=300, help='size of word embeddings')
    padd('--nhidden',
         type=int,
         default=300,
         help='number of hidden units per layer in LSTM')
    padd('--nlayers', type=int, default=2, help='number of layers')
    padd('--noise_radius',
         type=float,
         default=0.2,
         help='stdev of noise for autoencoder (regularizer)')
    padd('--noise_anneal',
         type=float,
         default=0.995,
         help='anneal noise_radius exponentially by this every 100 iterations')
    padd('--hidden_init',
         action='store_true',
         help="initialize decoder hidden state with encoder's")
    padd('--arch_i',
         type=str,
         default='300-300',
         help='inverter architecture (MLP)')
    padd('--arch_g',
         type=str,
         default='300-300',
         help='generator architecture (MLP)')
    padd('--arch_d',
         type=str,
         default='300-300',
         help='critic/discriminator architecture (MLP)')
    padd('--arch_conv_filters',
         type=str,
         default='500-700-1000',
         help='encoder filter sizes for different convolutional layers')
    padd('--arch_conv_strides',
         type=str,
         default='1-2-2',
         help='encoder strides for different convolutional layers')
    padd('--arch_conv_windows',
         type=str,
         default='3-3-3',
         help='encoder window sizes for different convolutional layers')
    padd('--z_size',
         type=int,
         default=100,
         help='dimension of random noise z to feed into generator')
    padd('--temp',
         type=float,
         default=1,
         help='softmax temperature (lower --> more discrete)')
    padd('--enc_grad_norm',
         type=_str2bool,
         default=True,
         help='norm code gradient from critic->encoder')
    padd('--train_emb',
         type=_str2bool,
         default=True,
         help='Train Glove Embeddings')
    padd('--gan_toenc',
         type=float,
         default=-0.01,
         help='weight factor passing gradient from gan to encoder')
    padd('--dropout',
         type=float,
         default=0.0,
         help='dropout applied to layers (0 = no dropout)')
    padd('--useJS',
         type=_str2bool,
         default=True,
         help='use Jenson Shannon distance')
    padd('--perturb_z',
         type=_str2bool,
         default=True,
         help='perturb noise space z instead of hidden c')
    padd('--max_seq_len', type=int, default=200, help='max_seq_len')
    padd('--gamma', type=float, default=0.95, help='Discount Factor')
    padd('--model',
         type=str,
         default="lstm_arch",
         help='classification model name')
    padd('--distance_func',
         type=str,
         default="cosine",
         help='NN distance function')
    padd('--hidden_dim', type=int, default=128, help='hidden_dim')
    padd('--burn_in', type=int, default=500, help='Train VAE burnin')
    padd('--beta', type=float, default=0., help='Entropy reg')
    padd('--embedding_training',
         type=_str2bool,
         default=False,
         help='embedding_training')
    padd('--seqgan_reward',
         action='store_true',
         default=False,
         help='use seq gan reward')
    padd('--train_classifier',
         action='store_true',
         default=False,
         help='Train Classifier from scratch')
    padd('--diff_nn',
         action='store_true',
         default=False,
         help='Backprop through Nearest Neighbors')
    # Bells
    padd('--no-cuda',
         action='store_true',
         default=False,
         help='disables CUDA training')
    padd('--data_parallel',
         action='store_true',
         default=False,
         help="Use multiple GPUs")
    padd('--save_adv_samples',
         action='store_true',
         default=False,
         help='Write adversarial samples to disk')
    padd('--nearest_neigh_all',
         action='store_true',
         default=False,
         help='Evaluate near. neig. for whole evaluation set')
    padd("--comet",
         action="store_true",
         default=False,
         help='Use comet for logging')
    padd(
        "--offline_comet",
        action="store_true",
        default=False,
        help=
        'Use comet offline. To upload, after training run: comet-upload file.zip'
    )
    padd("--comet_username",
         type=str,
         default="joeybose",
         help='Username for comet logging')
    # NOTE(review): an API key is committed here as the default; it should be
    # rotated and supplied through settings.json or the command line only.
    padd("--comet_apikey", type=str,\
            default="Ht9lkWvTm58fRo9ccgpabq5zV",help='Api for comet logging')
    padd('--debug', default=False, action='store_true', help='Debug')
    padd('--debug_neighbour',
         default=False,
         action='store_true',
         help='Debug nearest neighbour training')
    padd('--load_model',
         default=False,
         action='store_true',
         help='Whether to load a checkpointed model')
    padd('--save_model',
         default=False,
         action='store_true',
         help='Whether to checkpointed model')
    padd('--model_path', type=str, default="saved_models/lstm_torchtext2.pt",\
                        help='where to save/load target model')
    padd('--adv_model_path', type=str, default="saved_models/adv_model.pt",\
                        help='where to save/load adversarial')
    # Passing this flag sets args.no_load_embedding to False (store_false).
    padd('--no_load_embedding',
         action='store_false',
         default=True,
         help='load Glove embeddings')
    padd('--namestr', type=str, default='BMD Text', \
            help='additional info in output filename to describe experiments')
    padd('--dataset', type=str, default="imdb", help='dataset')
    padd('--clip', type=float, default=1, help='gradient clipping, max norm')
    # NOTE(review): the help text below does not match the option name.
    padd('--use_glove', type=str, default="true", help='gpu number')
    args = parser.parse_args()
    args.classes = 2
    args.sample_file = "temp/adv_samples.txt"
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    # Check if settings file
    if os.path.isfile("settings.json"):
        with open('settings.json') as f:
            data = json.load(f)
        args.comet_apikey = data["apikey"]
        args.comet_username = data["username"]

    # Prep file to save adversarial samples
    if args.save_adv_samples:
        now = datetime.datetime.now()
        # The target directory may not exist on a fresh checkout.
        sample_dir = os.path.dirname(args.sample_file)
        if sample_dir:
            os.makedirs(sample_dir, exist_ok=True)
        if os.path.exists(args.sample_file):
            os.remove(args.sample_file)
        with open(args.sample_file, 'w') as f:
            f.write("Adversarial samples starting:\n{}\n".format(now))

    # Comet logging
    args.device = torch.device("cuda" if use_cuda else "cpu")
    if args.comet and not args.offline_comet:
        experiment = Experiment(api_key=args.comet_apikey,
                                project_name="black-magic-design",
                                workspace=args.comet_username)
    elif args.offline_comet:
        offline_path = "temp/offline_comet"
        if not os.path.exists(offline_path):
            os.makedirs(offline_path)
        from comet_ml import OfflineExperiment
        experiment = OfflineExperiment(project_name="black-magic-design",
                                       workspace=args.comet_username,
                                       offline_directory=offline_path)

    # To upload offline comet, run: comet-upload file.zip
    if args.comet or args.offline_comet:
        experiment.set_name(args.namestr)

        def log_text(self, msg):
            # Change line breaks for html breaks
            msg = msg.replace('\n', '<br>')
            self.log_html("<p>{}</p>".format(msg))

        # Replace the experiment's log_text so that text is logged as HTML.
        experiment.log_text = MethodType(log_text, experiment)
        args.experiment = experiment

    return args
Beispiel #8
0
        ipdb.set_trace = lambda: None

    # Comet logging
    args.device = torch.device("cuda" if use_cuda else "cpu")
    if args.comet and not args.offline_comet:
        experiment = Experiment(api_key=args.comet_apikey,
                                project_name="black-magic-design",
                                workspace=args.comet_username)
    elif args.offline_comet:
        offline_path = "temp/offline_comet"
        if not os.path.exists(offline_path):
            os.makedirs(offline_path)
        from comet_ml import OfflineExperiment
        experiment = OfflineExperiment(project_name="black-magic-design",
                                       workspace=args.comet_username,
                                       offline_directory=offline_path)

    # To upload offline comet, run: comet-upload file.zip
    if args.comet or args.offline_comet:
        experiment.set_name(args.namestr)

        def log_text(self, msg):
            # Change line breaks for html breaks
            msg = msg.replace('\n', '<br>')
            self.log_html("<p>{}</p>".format(msg))

        experiment.log_text = MethodType(log_text, experiment)
        args.experiment = experiment

    main(args)