def main():
    # init random seed
    init_random_seed(params.manual_seed)

    # Load dataset
    mnist_data_loader = get_mnist(train=True, download=True)
    mnist_data_loader_eval = get_mnist(train=False, download=True)
    usps_data_loader = get_usps(train=True, download=True)
    usps_data_loader_eval = get_usps(train=False, download=True)

    # Model init ADDA
    src_encoder = model_init(Encoder(), params.src_encoder_adda_rb_path)
    tgt_encoder = model_init(Encoder(), params.tgt_encoder_adda_rb_path)
    critic = model_init(Discriminator(), params.disc_adda_rb_path)
    clf = model_init(Classifier(), params.clf_adda_rb_path)

    # Train source model for adda
    print(
        "====== Robust training source encoder and classifier in MNIST domain ======"
    )
    if not (src_encoder.pretrained and clf.pretrained
            and params.model_trained):
        src_encoder, clf = train_src_robust(src_encoder, clf,
                                            mnist_data_loader)

    # Eval source model
    print("====== Evaluating classifier for MNIST domain ======")
    eval_tgt(src_encoder, clf, mnist_data_loader_eval)

    # Train target encoder
    print("====== Robust training encoder for USPS domain ======")
    # Initialize target encoder's weights with those of the source encoder
    if not tgt_encoder.pretrained:
        tgt_encoder.load_state_dict(src_encoder.state_dict())

    if not (tgt_encoder.pretrained and critic.pretrained
            and params.model_trained):
        tgt_encoder = train_tgt_adda(src_encoder,
                                     tgt_encoder,
                                     clf,
                                     critic,
                                     mnist_data_loader,
                                     usps_data_loader,
                                     usps_data_loader_eval,
                                     robust=True)

    # Eval target encoder on test set of target dataset
    print("====== Ealuating classifier for encoded USPS domain ======")
    print("-------- Source only --------")
    eval_tgt(src_encoder, clf, usps_data_loader_eval)
    print("-------- Domain adaption --------")
    eval_tgt(tgt_encoder, clf, usps_data_loader_eval)
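The helpers init_random_seed and model_init are used by every example on this page but are not shown. A minimal sketch of what they might look like, inferred purely from how they are called above (names, signatures and checkpoint handling are assumptions, not this repository's actual code):

import os
import random

import numpy as np
import torch


def init_random_seed(manual_seed):
    # Seed Python, NumPy and PyTorch RNGs; fall back to a random seed if none is given.
    seed = random.randint(1, 10000) if manual_seed is None else manual_seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def model_init(net, restore_path=None):
    # Restore weights from a checkpoint if it exists, flag the model as pretrained,
    # and move it to the GPU when one is available.
    net.pretrained = False
    if restore_path and os.path.exists(restore_path):
        net.load_state_dict(torch.load(restore_path))
        net.pretrained = True
    return net.cuda() if torch.cuda.is_available() else net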
Example #2
def main():
    # init random seed
    init_random_seed(params.manual_seed)

    # Load dataset
    svhn_data_loader = get_svhn(split='train', download=True)
    svhn_data_loader_eval = get_svhn(split='test', download=True)
    mnist_data_loader = get_mnist(train=True, download=True)
    mnist_data_loader_eval = get_mnist(train=False, download=True)

    # Model init WDGRL
    tgt_encoder = model_init(Encoder(), params.encoder_wdgrl_path)
    critic = model_init(Discriminator(in_dims=params.d_in_dims,
                                      h_dims=params.d_h_dims,
                                      out_dims=params.d_out_dims),
                        params.disc_wdgrl_path)
    clf = model_init(Classifier(), params.clf_wdgrl_path)

    # Train critic to optimality
    print("====== Training critic ======")
    if not (critic.pretrained and params.model_trained):
        critic = train_critic_wdgrl(tgt_encoder, critic, svhn_data_loader, mnist_data_loader)

    # Train target encoder
    print("====== Training encoder for both SVHN and MNIST domains ======")
    if not (tgt_encoder.pretrained and clf.pretrained and params.model_trained):
        tgt_encoder, clf = train_tgt_wdgrl(tgt_encoder, clf, critic,
                                           svhn_data_loader, mnist_data_loader,
                                           robust=False)

    # Eval target encoder on test set of target dataset
    print("====== Evaluating classifier for encoded SVHN and MNIST domains ======")
    print("-------- SVHN domain --------")
    eval_tgt(tgt_encoder, clf, svhn_data_loader_eval)
    print("-------- MNIST adaption --------")
    eval_tgt(tgt_encoder, clf, mnist_data_loader_eval)
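train_critic_wdgrl itself is not shown here. In WDGRL the domain critic is typically trained toward an estimate of the Wasserstein distance with a gradient penalty; a generic sketch of such a penalty follows (function name and signature are assumptions, not this repository's API):

import torch


def gradient_penalty(critic, h_src, h_tgt):
    # Penalize critic gradients whose norm deviates from 1, evaluated on random
    # interpolations between source and target features (WGAN-GP style).
    alpha = torch.rand(h_src.size(0), 1, device=h_src.device)
    interpolates = (alpha * h_src + (1 - alpha) * h_tgt).requires_grad_(True)
    critic_out = critic(interpolates)
    grads = torch.autograd.grad(outputs=critic_out, inputs=interpolates,
                                grad_outputs=torch.ones_like(critic_out),
                                create_graph=True, retain_graph=True)[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()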
Example #3
def main():
    # init random seed
    init_random_seed(params.manual_seed)

    # Load dataset
    mnist_data_loader = get_mnist(train=True, download=True)
    mnist_data_loader_eval = get_mnist(train=False, download=True)
    usps_data_loader = get_usps(train=True, download=True)
    usps_data_loader_eval = get_usps(train=False, download=True)

    # Model init RevGrad
    tgt_encoder = model_init(Encoder(), params.tgt_encoder_revgrad_path)
    critic = model_init(Discriminator(), params.disc_revgard_path)
    clf = model_init(Classifier(), params.clf_revgrad_path)

    # Train models
    print("====== Training source encoder and classifier in MNIST and USPS domains ======")
    if not (tgt_encoder.pretrained and clf.pretrained and critic.pretrained and params.model_trained):
        tgt_encoder, clf, critic = train_revgrad(tgt_encoder, clf, critic,
                                                 mnist_data_loader, usps_data_loader, robust=False)

    # Eval target encoder on test set of target dataset
    print("====== Evaluating classifier for encoded MNIST and USPS domain ======")
    print("-------- MNIST domain --------")
    eval_tgt(tgt_encoder, clf, mnist_data_loader_eval)
    print("-------- USPS adaption --------")
    eval_tgt(tgt_encoder, clf, usps_data_loader_eval)
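train_revgrad is not included in this snippet. RevGrad (DANN) hinges on a gradient reversal layer that acts as the identity in the forward pass and flips the gradient sign in the backward pass; a standard sketch of that layer, not taken from this repository:

import torch
from torch.autograd import Function


class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse (and scale) the gradient flowing back into the encoder.
        return grad_output.neg() * ctx.lambd, None


def grad_reverse(x, lambd=1.0):
    # Features routed through here confuse the domain critic while the encoder
    # is pushed toward domain-invariant representations.
    return GradReverse.apply(x, lambd)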
Example #4
def main():
    # init random seed
    init_random_seed(params.manual_seed)

    # Load dataset
    mnist_data_loader = get_mnist(train=True, download=True)
    mnist_data_loader_eval = get_mnist(train=False, download=True)
    usps_data_loader = get_usps(train=True, download=True)
    usps_data_loader_eval = get_usps(train=False, download=True)

    # Model init WDGRL
    tgt_encoder = model_init(Encoder(), params.encoder_wdgrl_rb_path)
    critic = model_init(Discriminator(), params.disc_wdgrl_rb_path)
    clf = model_init(Classifier(), params.clf_wdgrl_rb_path)

    # Train target encoder
    print("====== Robust Training encoder for both MNIST and USPS domains ======")
    if not (tgt_encoder.pretrained and clf.pretrained and params.model_trained):
        tgt_encoder, clf = train_tgt_wdgrl(tgt_encoder, clf, critic,
                                           mnist_data_loader, usps_data_loader, usps_data_loader_eval, robust=True)

    # Eval target encoder on test set of target dataset
    print("====== Evaluating classifier for encoded MNIST and USPS domains ======")
    print("-------- MNIST domain --------")
    eval_tgt_robust(tgt_encoder, clf, mnist_data_loader_eval)
    print("-------- USPS adaption --------")
    eval_tgt_robust(tgt_encoder, clf, usps_data_loader_eval)
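eval_tgt, used throughout these examples, is another helper that is only called, never defined. A plausible sketch of the evaluation loop it performs (the global device and the printed format are assumptions):

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def eval_tgt(encoder, classifier, data_loader):
    # Classify encoded images and report average loss and accuracy.
    encoder.eval()
    classifier.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            preds = classifier(encoder(images))
            total_loss += criterion(preds, labels).item()
            correct += (preds.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    print("Avg loss = {:.4f}, accuracy = {:.2%}".format(
        total_loss / len(data_loader), correct / total))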
Example #5
def main():
    # init random seed
    init_random_seed(params.manual_seed)

    # Load dataset
    svhn_data_loader = get_svhn(split='train', download=True)
    svhn_data_loader_eval = get_svhn(split='test', download=True)
    mnist_data_loader = get_mnist(train=True, download=True)
    mnist_data_loader_eval = get_mnist(train=False, download=True)

    # Model init DANN
    tgt_encoder = model_init(Encoder(), params.tgt_encoder_dann_rb_path)
    critic = model_init(
        Discriminator(in_dims=params.d_in_dims,
                      h_dims=params.d_h_dims,
                      out_dims=params.d_out_dims), params.disc_dann_rb_path)
    clf = model_init(Classifier(), params.clf_dann_rb_path)

    # Train models
    print(
        "====== Training source encoder and classifier in SVHN and MNIST domains ======"
    )
    if not (tgt_encoder.pretrained and clf.pretrained and critic.pretrained
            and params.model_trained):
        tgt_encoder, clf, critic = train_dann(tgt_encoder,
                                              clf,
                                              critic,
                                              svhn_data_loader,
                                              mnist_data_loader,
                                              mnist_data_loader_eval,
                                              robust=True)

    # Eval target encoder on test set of target dataset
    print(
        "====== Evaluating classifier for encoded SVHN and MNIST domains ======"
    )
    print("-------- SVHN domain --------")
    eval_tgt_robust(tgt_encoder, clf, svhn_data_loader_eval)
    print("-------- MNIST adaption --------")
    eval_tgt_robust(tgt_encoder, clf, mnist_data_loader_eval)
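eval_tgt_robust presumably measures accuracy under adversarial perturbations of the evaluation images. A generic PGD attack such an evaluation could rely on is sketched below; the epsilon, step size and iteration count are illustrative values, not taken from this code base:

import torch
import torch.nn.functional as F


def pgd_attack(encoder, classifier, images, labels, eps=0.3, alpha=0.01, iters=40):
    # Projected gradient ascent on the loss, clipped to an L-infinity ball
    # of radius eps around the clean images.
    images = images.clone().detach()
    adv = images + torch.empty_like(images).uniform_(-eps, eps)
    for _ in range(iters):
        adv.requires_grad_(True)
        loss = F.cross_entropy(classifier(encoder(adv)), labels)
        grad = torch.autograd.grad(loss, adv)[0]
        adv = adv.detach() + alpha * grad.sign()
        adv = images + torch.clamp(adv - images, -eps, eps)
        adv = torch.clamp(adv, 0, 1).detach()
    return adv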
Example #7
class PyTorchModelWrapper(nn.Module):
    def __init__(self, name, **kwargs):
        super(PyTorchModelWrapper, self).__init__()
        # framework parameters
        self.device = kwargs.get('device', 'cuda')
        self.dim = kwargs.get('dim', 200)
        self.edge_index = kwargs['ei']
        self.edge_type = kwargs['et']
        ent_sizes = kwargs.get('ent_sizes', [15000, 15000])
        rel_sizes = kwargs.get('rel_sizes', [500, 500])
        dim = self.dim
        self.ent_num = ent_sizes[0] + ent_sizes[1]
        self.rel_num = rel_sizes[0] + rel_sizes[1]
        self.ent_split = ent_sizes[0]
        self.rel_split = rel_sizes[0]
        # load model specific parameters
        param = PARAMS.get(name, {})

        self.hiddens = param.get('hiddens', (dim, dim, dim))
        self.dim = dim = self.hiddens[0]
        self.heads = param.get('heads', (1, 1, 1))
        self.feat_drop = param.get('feat_drop', 0.2)
        self.attn_drop = param.get('attn_drop', 0.)
        self.negative_slope = param.get('negative_slope', 0.)
        self.update = param.get('update', 10)
        self.dist = param.get('dist', 'manhattan')
        self.lr = param.get('lr', 0.005)
        self.share = param.get('share', False)
        sampling = param.get('sampling', ['N'])
        k = param.get('k', [25])
        margin = param.get('margin', [1])
        alpha = param.get('alpha', [1])

        encoder_name = param.get('encoder', None)
        decoder_names = param.get('decoder', ['align'])

        self.encoder = Encoder(name, self.hiddens, self.heads, F.elu,
                               self.feat_drop, self.attn_drop,
                               self.negative_slope,
                               False) if encoder_name else None

        knowledge_decoder = []
        for idx, decoder_name in enumerate(decoder_names):
            knowledge_decoder.append(
                Decoder(
                    decoder_name,
                    params={
                        "e_num": self.ent_num,
                        "r_num": self.rel_num,
                        "dim": self.hiddens[-1],
                        "feat_drop": self.feat_drop,
                        "train_dist": self.dist,
                        "sampling": sampling[idx],
                        "k": k[idx],
                        "margin": margin[idx],
                        "alpha": alpha[idx],
                        "boot": False,
                        # pass other useful parameters to Decoder
                    }))
        self.knowledge_decoder = nn.ModuleList(knowledge_decoder)

        self.cached_sample = {}
        self.preprocessing()
        self.init_emb(encoder_name, decoder_names, margin, self.ent_num,
                      self.rel_num, self.device)

    @torch.no_grad()
    def preprocessing(self):
        edge_index0, edge_index1 = self.edge_index
        edge_index1 = edge_index1 + self.ent_split
        self.edge_index = torch.cat([edge_index0, edge_index1], dim=1)

        rel0, rel1 = self.edge_type
        rel1 = rel1 + self.rel_split
        self.rel = torch.cat([rel0, rel1], dim=0)

        total = self.edge_index.size(1)
        ei, et = apply(lambda x: x.cpu().numpy(), self.edge_index, self.rel)
        self.triples = [(ei[0][i], et[i], ei[1][i]) for i in range(total)]
        self.edge_index = add_self_loops(
            remove_self_loops(self.edge_index)[0])[0].t()
        self.ids = [
            set(range(0, self.ent_split)),
            set(range(self.ent_split, self.ent_num))
        ]

    def init_emb(self, encoder_name, decoder_names, margin, ent_num, rel_num,
                 device):
        e_scale, r_scale = 1, 1
        if not encoder_name:
            if decoder_names == ["rotate"]:
                r_scale = r_scale / 2
            elif decoder_names == ["hake"]:
                r_scale = (r_scale / 2) * 3
            elif decoder_names == ["transh"]:
                r_scale = r_scale * 2
            elif decoder_names == ["transr"]:
                r_scale = self.hiddens[0] + 1
        self.ent_embeds = nn.Embedding(ent_num,
                                       self.hiddens[0] * e_scale).to(device)
        self.rel_embeds = nn.Embedding(rel_num, int(self.hiddens[0] *
                                                    r_scale)).to(device)
        if decoder_names == ["rotate"] or decoder_names == ["hake"]:
            ins_range = (margin[0] + 2.0) / float(self.hiddens[0] * e_scale)
            nn.init.uniform_(tensor=self.ent_embeds.weight,
                             a=-ins_range,
                             b=ins_range)
            rel_range = (margin[0] + 2.0) / float(self.hiddens[0] * r_scale)
            nn.init.uniform_(tensor=self.rel_embeds.weight,
                             a=-rel_range,
                             b=rel_range)
            if decoder_names == ["hake"]:
                r_dim = int(self.hiddens[0] / 2)
                nn.init.ones_(tensor=self.rel_embeds.weight[:,
                                                            r_dim:2 * r_dim])
                nn.init.zeros_(tensor=self.rel_embeds.weight[:, 2 * r_dim:3 *
                                                             r_dim])
        else:
            nn.init.xavier_normal_(self.ent_embeds.weight)
            nn.init.xavier_normal_(self.rel_embeds.weight)
        if "alignea" in decoder_names or "mtranse_align" in decoder_names or "transedge" in decoder_names:
            self.ent_embeds.weight.data = F.normalize(self.ent_embeds.weight,
                                                      p=2,
                                                      dim=1)
            self.rel_embeds.weight.data = F.normalize(self.rel_embeds.weight,
                                                      p=2,
                                                      dim=1)
        # elif "transr" in decoder_names:
        #     assert self.args.pre != ""
        #     self.ent_embeds.weight.data = torch.from_numpy(np.load(self.args.pre + "_ins.npy")).to(device)
        #     self.rel_embeds.weight[:, :self.hiddens[0]].data = torch.from_numpy(
        #         np.load(self.args.pre + "_rel.npy")).to(device)
        self.enh_ins_emb = self.ent_embeds.weight.cpu().detach().numpy()
        self.mapping_ins_emb = None

    def run_test(self, pair):
        npy_embeds = apply(lambda x: x.cpu().numpy(), *self.get_embed())
        npy_sim = sim(*npy_embeds, metric=self.dist, normalize=True, csls_k=10)
        evaluate_sim_matrix(pair, torch.from_numpy(npy_sim).to(self.device))

    def get_embed(self):
        emb = self.ent_embeds.weight
        if self.encoder:
            self.encoder.eval()
            emb = self.encoder.forward(self.edge_index, emb, None)
        embs = apply(norm_embed, *[emb[:self.ent_split], emb[self.ent_split:]])
        return embs

    def refresh_cache(self):
        self.cached_sample = {}

    def share_triples(self, pairs, triples):
        ill = {k: v for k, v in pairs}
        new_triple = []
        for (h, r, t) in triples:
            if h in ill:
                h = ill[h]
            if t in ill:
                t = ill[t]
            new_triple.append((h, r, t))
        return list(set(new_triple))

    def gen_sparse_graph_from_triples(self, triples, ins_num, with_r=False):
        edge_dict = {}
        for (h, r, t) in triples:
            if h != t:
                if (h, t) not in edge_dict:
                    edge_dict[(h, t)] = []
                    edge_dict[(t, h)] = []
                edge_dict[(h, t)].append(r)
                edge_dict[(t, h)].append(-r)
        if with_r:
            edges = [[h, t] for (h, t) in edge_dict for r in edge_dict[(h, t)]]
            values = [1 for (h, t) in edge_dict for r in edge_dict[(h, t)]]
            r_ij = [abs(r) for (h, t) in edge_dict for r in edge_dict[(h, t)]]
            edges = np.array(edges, dtype=np.int32)
            values = np.array(values, dtype=np.float32)
            r_ij = np.array(r_ij, dtype=np.float32)
            return edges, values, r_ij
        else:
            edges = [[h, t] for (h, t) in edge_dict]
            values = [1 for (h, t) in edge_dict]
        # add self-loop
        edges += [[e, e] for e in range(ins_num)]
        values += [1 for e in range(ins_num)]
        edges = np.array(edges, dtype=np.int32)
        values = np.array(values, dtype=np.float32)
        return edges, values, None

    def train1step(self, it, pairs, opt):
        pairs = torch.stack([pairs[0], pairs[1] + self.ent_split])
        pairs = pairs.t().cpu().numpy()
        triples = self.triples
        ei, et = self.edge_index, self.edge_type
        if self.share:
            triples = self.share_triples(pairs, triples)
            ei = self.gen_sparse_graph_from_triples(triples, self.ent_num)

        for decoder in self.knowledge_decoder:
            self._train(it, opt, self.encoder, decoder, ei, triples, pairs,
                        self.ids, self.ent_embeds.weight,
                        self.rel_embeds.weight)

    def _train(self, it, opt, encoder, decoder, edges, triples, ills, ids,
               ins_emb, rel_emb):
        device = self.device
        if encoder:
            encoder.train()
        decoder.train()
        losses = []
        if "pos_" + decoder.print_name not in self.cached_sample or it % self.update == 0:
            if decoder.name in ["align", "mtranse_align", "n_r_align"]:
                self.cached_sample["pos_" + decoder.print_name] = ills.tolist()
                self.cached_sample["pos_" + decoder.print_name] = np.array(
                    self.cached_sample["pos_" + decoder.print_name])
            else:
                self.cached_sample["pos_" + decoder.print_name] = triples
            np.random.shuffle(self.cached_sample["pos_" + decoder.print_name])
            # print("train size:", len(self.cached_sample["pos_"+decoder.print_name]))

        train = self.cached_sample["pos_" + decoder.print_name]
        train_batch_size = len(train)
        for i in range(0, len(train), train_batch_size):
            pos_batch = train[i:i + train_batch_size]

            if (decoder.print_name + str(i) not in self.cached_sample
                    or it % self.update == 0) and decoder.sampling_method:
                self.cached_sample[decoder.print_name +
                                   str(i)] = decoder.sampling_method(
                                       pos_batch,
                                       triples,
                                       ills,
                                       ids,
                                       decoder.k,
                                       params={
                                           "emb": self.enh_ins_emb,
                                           "metric": self.dist,
                                       })

            if decoder.sampling_method:
                neg_batch = self.cached_sample[decoder.print_name + str(i)]

            opt.zero_grad()
            if decoder.sampling_method:
                neg = torch.LongTensor(neg_batch).to(device)
                if neg.size(0) > len(pos_batch) * decoder.k:
                    pos = torch.LongTensor(pos_batch).repeat(decoder.k * 2,
                                                             1).to(device)
                elif hasattr(decoder.func, "loss") and decoder.name not in [
                        "rotate", "hake", "conve", "mmea", "n_transe"
                ]:
                    pos = torch.LongTensor(pos_batch).to(device)
                else:
                    pos = torch.LongTensor(pos_batch).repeat(decoder.k,
                                                             1).to(device)
            else:
                pos = torch.LongTensor(pos_batch).to(device)

            if encoder:
                enh_emb = encoder.forward(edges, ins_emb, None)
            else:
                enh_emb = ins_emb

            self.enh_ins_emb = enh_emb[0].cpu().detach().numpy(
            ) if encoder and encoder.name == "naea" else enh_emb.cpu().detach(
            ).numpy()
            if decoder.name == "n_r_align":
                rel_emb = ins_emb

            if decoder.sampling_method:
                pos_score = decoder.forward(enh_emb, rel_emb, pos)
                neg_score = decoder.forward(enh_emb, rel_emb, neg)
                target = torch.ones(neg_score.size()).to(device)

                loss = decoder.loss(pos_score, neg_score,
                                    target) * decoder.alpha
            else:
                loss = decoder.forward(enh_emb, rel_emb, pos) * decoder.alpha

            loss.backward()

            opt.step()
            losses.append(loss.item())

        return np.mean(losses)
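The apply helper used in preprocessing, run_test and get_embed is not defined in this snippet. From its call sites it appears to map a function over several tensors and return the results; a minimal sketch under that assumption:

def apply(fn, *args):
    # Map fn over every argument; unwrap single results so that both
    # `ei, et = apply(f, a, b)` and `x = apply(f, a)` work as used above.
    out = tuple(fn(arg) for arg in args)
    return out if len(out) > 1 else out[0]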
Example #8
def main():
    """
    Training and validation.
    """

    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=encoder_lr) if fine_tune_encoder else None

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    # if torch.cuda.device_count() > 1:
    #
    #     decoder = nn.DataParallel(decoder,device_ids=[0,1])
    #     encoder = nn.DataParallel(encoder,device_ids=[0,1])

    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # summary(encoder,(3,256,256))

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'TRAIN',
        transform=transforms.Compose([normalize])),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(CaptionDataset(
        data_folder,
        data_name,
        'VAL',
        transform=transforms.Compose([normalize])),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=workers,
                                             pin_memory=True)

    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
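adjust_learning_rate and save_checkpoint come from the surrounding image-captioning project and are not shown here. The learning-rate decay is most likely a simple multiplicative shrink, roughly like this sketch (an assumption, not the project's exact code):

def adjust_learning_rate(optimizer, shrink_factor):
    # Multiply the learning rate of every parameter group by shrink_factor (< 1).
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor
    print("New learning rate: %f" % (optimizer.param_groups[0]['lr'],))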
Example #9
def main():
    # init random seed
    init_random_seed(params.manual_seed)

    # Load dataset
    mnist_data_loader = get_mnist(train=True, download=True)
    mnist_data_loader_eval = get_mnist(train=False, download=True)
    usps_data_loader = get_usps(train=True, download=True)
    usps_data_loader_eval = get_usps(train=False, download=True)

    # Model init DANN
    tgt_encoder = model_init(Encoder(), params.tgt_encoder_dann_rb_path)
    critic = model_init(Discriminator(), params.disc_dann_rb_path)
    clf = model_init(Classifier(), params.clf_dann_rb_path)

    # Train models
    print(
        "====== Robust Training source encoder and classifier in MNIST and USPS domains ======"
    )
    if not (tgt_encoder.pretrained and clf.pretrained and critic.pretrained
            and params.model_trained):
        tgt_encoder, clf, critic = train_dann(tgt_encoder,
                                              clf,
                                              critic,
                                              mnist_data_loader,
                                              usps_data_loader,
                                              usps_data_loader_eval,
                                              robust=False)

    # Eval target encoder on test set of target dataset
    print(
        "====== Evaluating classifier for encoded MNIST and USPS domains ======"
    )
    print("-------- MNIST domain --------")
    eval_tgt_robust(tgt_encoder, clf, critic, mnist_data_loader_eval)
    print("-------- USPS adaption --------")
    eval_tgt_robust(tgt_encoder, clf, critic, usps_data_loader_eval)

    print("====== Pseudo labeling on USPS domain ======")
    pseudo_label(tgt_encoder, clf, "usps_train_pseudo", usps_data_loader)

    # Init a new model
    tgt_encoder = model_init(Encoder(), params.tgt_encoder_path)
    clf = model_init(Classifier(), params.clf_path)

    # Load pseudo labeled dataset
    usps_pseudo_loader = get_usps(train=True, download=True, get_pseudo=True)

    print("====== Standard training on USPS domain with pseudo labels ======")
    if not (tgt_encoder.pretrained and clf.pretrained):
        train_src_adda(tgt_encoder, clf, usps_pseudo_loader, mode='ADV')
    print("====== Evaluating on USPS domain with real labels ======")
    eval_tgt(tgt_encoder, clf, usps_data_loader_eval)

    tgt_encoder = model_init(Encoder(), params.tgt_encoder_rb_path)
    clf = model_init(Classifier(), params.clf_rb_path)
    print("====== Robust training on USPS domain with pseudo labels ======")
    if not (tgt_encoder.pretrained and clf.pretrained):
        train_src_robust(tgt_encoder, clf, usps_pseudo_loader, mode='ADV')
    print("====== Evaluating on USPS domain with real labels ======")
    eval_tgt(tgt_encoder, clf, usps_data_loader_eval)
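pseudo_label and the get_pseudo=True flag of get_usps are not shown either. A rough sketch of what the pseudo-labelling step might do (the file name and save location are assumptions made for illustration):

import torch


def pseudo_label(encoder, classifier, name, data_loader):
    # Predict labels for the unlabeled target data with the adapted model and
    # store them so get_usps(..., get_pseudo=True) can pick them up later.
    encoder.eval()
    classifier.eval()
    pseudo_labels = []
    with torch.no_grad():
        for images, _ in data_loader:
            dev = next(encoder.parameters()).device
            preds = classifier(encoder(images.to(dev)))
            pseudo_labels.append(preds.argmax(dim=1).cpu())
    torch.save(torch.cat(pseudo_labels), name + ".pt")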
Example #10
        exit(0)

    if mode == 'encode':
        '''
        Get encodings for each author's message from the Transformer-based encoders
        '''
        weight_path = os.path.join(
            weight_path, 'bestmodelo_split_{}_1.pt'.format(language[:2]))

        if not os.path.isfile(weight_path):
            print(
                f"{bcolors.FAIL}{bcolors.BOLD}ERROR: Weight path set improperly{bcolors.ENDC}"
            )
            exit(1)

        model = Encoder(interm_layer_size, max_length, language, mode_weigth)
        model.load(weight_path)
        if language[-1] == '_':
            model.transformer.load_adapter("logs/hate_adpt_{}".format(
                language[:2].lower()))

        tweets, _ = load_data_PAN(
            os.path.join(data_path, language[:2].lower()), False)
        preds = []
        encs = []
        batch_size = 200
        for i in tweets:
            e, l = model.get_encodings(i, batch_size)
            encs.append(e)
            preds.append(l)
        torch.save(np.array(encs),