Example #1
    def _init_optimizers(self):
        id_params = list(self.id_encoder.parameters())
        style_params = list(self.style_encoder.parameters())
        dis_params = list(self.discriminator.parameters())

        if self.config.optimizer_name == 'sgd':
            self.gen_optimizer = optim.SGD(id_params + style_params,
                                           lr=1e-3,
                                           weight_decay=0.0005,
                                           momentum=0.9)
            self.dis_optimizer = optim.SGD(dis_params,
                                           lr=1e-3,
                                           weight_decay=0.0005,
                                           momentum=0.9)
        else:
            self.gen_optimizer = optim.Adam(id_params + style_params,
                                            lr=1e-3,
                                            betas=[0.9, 0.999],
                                            weight_decay=5e-4)
            self.dis_optimizer = optim.Adam(dis_params,
                                            lr=1e-3,
                                            betas=[0.9, 0.999],
                                            weight_decay=5e-4)

        self.gen_lr_scheduler = WarmupMultiStepLR(self.gen_optimizer, [40, 70],
                                                  0.1, 0.01, 10, 'linear')
        self.dis_lr_scheduler = WarmupMultiStepLR(self.dis_optimizer, [40, 70],
                                                  0.1, 0.01, 10, 'linear')
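The scheduler arguments above are passed positionally. Matching them against the keyword-based calls in Examples #3 and #4 suggests the mapping below; this keyword-argument sketch of the generator call is an inference, not part of the original snippet.

        # Keyword form of the call above; parameter names inferred from Examples #3/#4.
        self.gen_lr_scheduler = WarmupMultiStepLR(
            self.gen_optimizer,
            milestones=[40, 70],      # decay the learning rate after epochs 40 and 70
            gamma=0.1,                # multiply the lr by 0.1 at each milestone
            warmup_factor=0.01,       # start the warmup at 1% of the base lr
            warmup_iters=10,          # warm up over the first 10 scheduler steps
            warmup_method='linear')   # ramp linearly from 0.01 * lr up to the base lr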
Example #2
def test_something(self):
    net = nn.Linear(10, 10)
    optimizer = make_optimizer(cfg, net)
    lr_scheduler = WarmupMultiStepLR(optimizer, [20, 40], warmup_iters=10)
    for i in range(50):
        lr_scheduler.step()
        for j in range(3):
            print(i, lr_scheduler.get_lr()[0])
            optimizer.step()
Example #3
def make_lr_scheduler(cfg, optimizer):
    return WarmupMultiStepLR(
        optimizer,
        cfg.TRAIN.steps,
        cfg.TRAIN.gamma,
        warmup_factor=cfg.TRAIN.warmup_factor,
        warmup_iters=cfg.TRAIN.warmup_iters,
        warmup_method=cfg.TRAIN.warmup_method,
    )
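None of these examples define WarmupMultiStepLR itself. Below is a minimal sketch of what such a scheduler typically computes, modeled on the maskrcnn-benchmark / detectron2 style implementation; the exact class used by each project may differ in defaults or argument names.

from bisect import bisect_right

import torch


class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3,
                 warmup_iters=500, warmup_method="linear", last_epoch=-1):
        if list(milestones) != sorted(milestones):
            raise ValueError("Milestones should be a list of increasing integers, got {}".format(milestones))
        if warmup_method not in ("constant", "linear"):
            raise ValueError("Only 'constant' or 'linear' warmup is supported, got {}".format(warmup_method))
        self.milestones = list(milestones)
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # During warmup, scale the base lr from warmup_factor * base_lr up to base_lr;
        # afterwards, multiply by gamma once for every milestone already passed.
        warmup_factor = 1.0
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            else:  # linear
                alpha = float(self.last_epoch) / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        return [base_lr * warmup_factor *
                self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs]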
Example #4
def train(cfg):
    # logger
    logger = logging.getLogger(name="merlin.baseline.train")
    logger.info("training...")

    # transform
    transform_train_list = [
        # transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC)
        transforms.Resize(size=cfg.INPUT.SIZE_TRAIN, interpolation=1),
        transforms.Pad(32),
        transforms.RandomCrop(cfg.INPUT.SIZE_TRAIN),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]

    transform_val_list = [
        transforms.Resize(size=cfg.INPUT.SIZE_TEST, interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]

    # prepare dataset
    train_dataset = MyDataset(root=cfg.DATA.ROOT, transform=transforms.Compose(transform_train_list), type='train')
    val_dataset = MyDataset(root=cfg.DATA.ROOT, transform=transforms.Compose(transform_val_list), type='val')
    train_loader = DataLoader(train_dataset,
                              batch_size=cfg.SOLVER.BATCH_SIZE,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=False)
    val_loader = DataLoader(val_dataset,
                            batch_size=cfg.SOLVER.BATCH_SIZE,
                            shuffle=True,
                            num_workers=8,
                            pin_memory=False)
    num_classes = cfg.MODEL.HEADS.NUM_CLASSES

    # prepare model
    model = build_model(cfg, num_classes)
    model = model.cuda()
    model = nn.DataParallel(model)

    # prepare solver
    optimizer = make_optimizer(cfg, model)
    scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
                                  cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)

    start_epoch = 0

    # Train and val
    since = time.time()
    for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCHS):
        model.train(True)
        logger.info("Epoch {}/{}".format(epoch, cfg.SOLVER.MAX_EPOCHS - 1))
        logger.info('-' * 10)

        running_loss = 0.0
        # Iterate over data
        it = 0
        running_acc = 0
        for data in train_loader:
            it += 1
            # get the inputs
            inputs, labels = data
            now_batch_size, c, h, w = inputs.shape
            if now_batch_size < cfg.SOLVER.BATCH_SIZE:  # skip the last batch
                continue

            # wrap them in Variable
            inputs = Variable(inputs.cuda().detach())
            labels = Variable(labels.cuda().detach())

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            out = model(inputs)
            loss_dict = get_loss(cfg, outs=out, label=labels)
            loss = sum(loss_dict.values())

            loss.backward()
            optimizer.step()
            scheduler.step()

            # statistics
            with torch.no_grad():
                _, preds = torch.max(out['pred_class_logits'], 1)
                running_loss += loss.item()
                running_acc += torch.sum(preds == labels.data).float().item() / cfg.SOLVER.BATCH_SIZE

            if it % 50 == 0:
                logger.info(
                    'epoch {}, iter {}, loss: {:.3f}, acc: {:.3f}, lr: {:.5f}'.format(
                        epoch, it, running_loss / it, running_acc / it,
                        optimizer.param_groups[0]['lr']))

        epoch_loss = running_loss / it
        epoch_acc = running_acc / it

        logger.info('epoch {} loss: {:.4f} Acc: {:.4f}'.format(epoch, epoch_loss, epoch_acc))

        # save checkpoint
        if epoch % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
            checkpoint = {'epoch': epoch + 1,
                          'model': model.module.state_dict() if (len(cfg.MODEL.DEVICE_ID) - 2) > 1 else model.state_dict(),
                          'optimizer': optimizer.state_dict()
                          }
            save_checkpoint(checkpoint, epoch, cfg)

        # evaluate
        if epoch % cfg.SOLVER.EVAL_PERIOD == 0:
            logger.info('evaluate...')
            model.train(False)

            total = 0.0
            correct = 0.0
            for data in val_loader:
                inputs, labels = data
                inputs = Variable(inputs.cuda().detach())
                labels = Variable(labels.cuda().detach())
                with torch.no_grad():
                    out = model(inputs)
                    _, preds = torch.max(out['pred_class_logits'], 1)
                    c = (preds == labels).squeeze()
                    total += c.size(0)
                    correct += c.float().sum().item()
            acc = correct / total
            logger.info('eval acc:{:.4f}'.format(acc))

        time_elapsed = time.time() - since
        logger.info('Training complete in {:.0f}m {:.0f}s\n'.format(
            time_elapsed // 60, time_elapsed % 60))

    return model
Example #5
class DGN(object):
    """
    The model which incorporates identity shuffling and reconstruction loss
    """
    def __init__(self, num_classes, config):
        self.config = config
        self.num_classes = num_classes
        self.device = torch.device('cuda')
        self._init_networks()
        self._init_optimizers()
        self._init_criterion()

    def _init_networks(self):
        #init models
        self.id_encoder = ID_encoder(self.num_classes,
                                     self.config).to(self.device)
        self.style_encoder = Style_encoder(self.config).to(self.device)
        self.decoder = F_Decoder(2, 2, self.style_encoder.output_dim, 3, 0,
                                 'adain', 'relu', 'reflect').to(self.device)
        self.discriminator = Discriminator(n_layer=4,
                                           middle_dim=32,
                                           num_scales=2).to(self.device)
        self.discriminator.apply(weights_init('gaussian'))
        self.mlp = MLP(2048,
                       self.get_num_adain_params(self.decoder),
                       256,
                       3,
                       norm='none',
                       activ='relu').to(self.device)

        self.model_list = []
        self.model_list.append(self.id_encoder)
        self.model_list.append(self.style_encoder)
        self.model_list.append(self.discriminator)

    def _init_optimizers(self):
        id_params = list(self.id_encoder.parameters())
        style_params = list(self.style_encoder.parameters())
        dis_params = list(self.discriminator.parameters())

        if self.config.optimizer_name == 'sgd':
            self.gen_optimizer = optim.SGD(id_params + style_params,
                                           lr=1e-3,
                                           weight_decay=0.0005,
                                           momentum=0.9)
            self.dis_optimizer = optim.SGD(dis_params,
                                           lr=1e-3,
                                           weight_decay=0.0005,
                                           momentum=0.9)
        else:
            self.gen_optimizer = optim.Adam(id_params + style_params,
                                            lr=1e-3,
                                            betas=[0.9, 0.999],
                                            weight_decay=5e-4)
            self.dis_optimizer = optim.Adam(dis_params,
                                            lr=1e-3,
                                            betas=[0.9, 0.999],
                                            weight_decay=5e-4)

        self.gen_lr_scheduler = WarmupMultiStepLR(self.gen_optimizer, [40, 70],
                                                  0.1, 0.01, 10, 'linear')
        self.dis_lr_scheduler = WarmupMultiStepLR(self.dis_optimizer, [40, 70],
                                                  0.1, 0.01, 10, 'linear')

    def _init_criterion(self):
        self.id_loss = nn.CrossEntropyLoss()
        self.reconst_loss = nn.L1Loss()
        self.triplet_loss = TripletLoss(0.5)

    def lr_scheduler_step(self):
        self.gen_lr_scheduler.step()
        self.dis_lr_scheduler.step()

    def encode(self, images):
        # encode an image into a foreground vector and a background vector
        id_scores, id_global_feat = self.id_encoder(images)
        style_feature_maps = self.style_encoder(images)
        return id_global_feat, style_feature_maps, id_scores

    def decode(self, id, style):
        adain_params = self.mlp(id)
        self.assign_adain_params(adain_params, self.decoder)
        images = self.decoder(style)
        return images

    def get_num_adain_params(self, model):
        # return the number of AdaIN parameters needed by the model
        num_adain_params = 0
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                num_adain_params += 2 * m.num_features
        return num_adain_params

    def assign_adain_params(self, adain_params, model):
        # assign the adain_params to the AdaIN layers in model
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                mean = adain_params[:, :m.num_features]
                std = adain_params[:, m.num_features:2 * m.num_features]
                m.bias = mean.contiguous().view(-1)
                m.weight = std.contiguous().view(-1)
                if adain_params.size(1) > 2 * m.num_features:
                    adain_params = adain_params[:, 2 * m.num_features:]

    def save_model(self, save_epoch):
        # save model
        for ii, _ in enumerate(self.model_list):
            model_dir_path = self.config.save_models_path + 'models_{}'.format(
                save_epoch)
            if os.path.exists(self.config.save_models_path):
                make_dirs(model_dir_path)
            else:
                make_dirs(self.config.save_models_path)
                make_dirs(model_dir_path)
            torch.save(
                self.model_list[ii].state_dict(),
                os.path.join(model_dir_path,
                             'model-{}_{}.pkl'.format(ii, save_epoch)))

        if self.config.max_save_model_num > 0:
            root, dirs, files = os_walk(self.config.save_models_path)
            total_save_models = len(dirs)
            if total_save_models > self.config.max_save_model_num:
                delet_index = total_save_models - self.config.max_save_model_num
                for to_delet in dirs[:delet_index]:
                    shutil.rmtree(self.config.save_models_path + to_delet)

    def resume_model(self, resume_epoch):
        for i, _ in enumerate(self.model_list):
            self.model_list[i].load_state_dict(
                torch.load(
                    os.path.join(
                        self.config.save_models_path +
                        'models_{}'.format(resume_epoch),
                        'model-{}_{}.pkl'.format(i, resume_epoch))))
        print('Time:{}, successfully resume model from {}'.format(
            time_now(), resume_epoch))

    def resume_model_from_path(self, path, resume_epoch):
        for i, _ in enumerate(self.model_list):
            self.model_list[i].load_state_dict(
                torch.load(
                    os.path.join(path + 'models_{}'.format(resume_epoch),
                                 'model-{}_{}.pkl'.format(i, resume_epoch))))


    def set_train(self):
        # set the models into training mode
        for i, _ in enumerate(self.model_list):
            self.model_list[i] = self.model_list[i].train()

    def set_eval(self):
        for i, _ in enumerate(self.model_list):
            self.model_list[i] = self.model_list[i].eval()
Example #6
def train(opt, train_iter, dev_iter, test_iter, syn_data, verbose=True):
    global_start = time.time()
    #logger = utils.getLogger()
    model = models.setup(opt)

    if opt.resume is not None:
        model = set_params(model, opt.resume)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if torch.cuda.is_available():
        model.cuda()
        #model=torch.nn.DataParallel(model)

    # set optimizer
    if opt.embd_freeze:
        model.embedding.weight.requires_grad = False
    else:
        model.embedding.weight.requires_grad = True
    params = [param for param in model.parameters() if param.requires_grad]
    optimizer = utils.getOptimizer(params,
                                   name=opt.optimizer,
                                   lr=opt.learning_rate,
                                   weight_decay=opt.weight_decay,
                                   scheduler=utils.get_lr_scheduler(
                                       opt.lr_scheduler))
    scheduler = WarmupMultiStepLR(optimizer, (40, 80), 0.1, 1.0 / 10.0, 2,
                                  'linear')

    from label_smooth import LabelSmoothSoftmaxCE
    if opt.label_smooth != 0:
        assert (opt.label_smooth <= 1 and opt.label_smooth > 0)
        loss_fun = LabelSmoothSoftmaxCE(lb_pos=1 - opt.label_smooth,
                                        lb_neg=opt.label_smooth)
    else:
        loss_fun = F.cross_entropy

    filename = None
    acc_adv_list = []
    start = time.time()
    kl_control = 0

    # initialize synonyms with the same embd
    from PWWS.word_level_process import word_process, get_tokenizer
    tokenizer = get_tokenizer(opt)

    if opt.embedding_prep == "same":
        father_dict = {}
        for index in range(1 + len(tokenizer.index_word)):
            father_dict[index] = index

        def get_father(x):
            if father_dict[x] == x:
                return x
            else:
                fa = get_father(father_dict[x])
                father_dict[x] = fa
                return fa

        for index in range(len(syn_data) - 1, 0, -1):
            syn_list = syn_data[index]
            for pos in syn_list:
                fa_pos = get_father(pos)
                fa_anch = get_father(index)
                if fa_pos == fa_anch:
                    father_dict[index] = index
                    father_dict[fa_anch] = index
                else:
                    father_dict[index] = index
                    father_dict[fa_anch] = index
                    father_dict[fa_pos] = index

        print("Same embedding for synonyms as embd prep.")
        set_different_embd = set()
        for key in father_dict:
            fa = get_father(key)
            set_different_embd.add(fa)
            with torch.no_grad():
                model.embedding.weight[key, :] = model.embedding.weight[fa, :]
        print(len(set_different_embd))

    elif opt.embedding_prep == "ge":
        print("Graph embedding as embd prep.")
        ge_file_path = opt.ge_file_path
        f = open(ge_file_path, 'rb')
        saved = pickle.load(f)
        ge_embeddings_dict = saved['walk_embeddings']
        #model = saved['model']
        f.close()
        with torch.no_grad():
            for key in ge_embeddings_dict:
                model.embedding.weight[int(key), :] = torch.FloatTensor(
                    ge_embeddings_dict[key])
    else:
        print("No embd prep.")

    from from_certified.attack_surface import WordSubstitutionAttackSurface, LMConstrainedAttackSurface
    if opt.lm_constraint:
        attack_surface = LMConstrainedAttackSurface.from_files(
            opt.certified_neighbors_file_path, opt.imdb_lm_file_path)
    else:
        attack_surface = WordSubstitutionAttackSurface.from_files(
            opt.certified_neighbors_file_path, opt.imdb_lm_file_path)

    best_adv_acc = 0
    for epoch in range(21):

        if opt.smooth_ce:
            if epoch < 10:
                weight_adv = epoch * 1.0 / 10
                weight_clean = 1 - weight_adv
            else:
                weight_adv = 1
                weight_clean = 0
        else:
            weight_adv = opt.weight_adv
            weight_clean = opt.weight_clean

        if epoch >= opt.kl_start_epoch:
            kl_control = 1

        sum_loss = sum_loss_adv = sum_loss_kl = sum_loss_clean = 0
        total = 0

        for iters, batch in enumerate(train_iter):

            text = batch[0].to(device)
            label = batch[1].to(device)
            anch = batch[2].to(device)
            pos = batch[3].to(device)
            neg = batch[4].to(device)
            anch_valid = batch[5].to(device).unsqueeze(2)
            text_like_syn = batch[6].to(device)
            text_like_syn_valid = batch[7].to(device)

            bs, sent_len = text.shape

            model.train()

            # zero grad
            optimizer.zero_grad()

            if opt.pert_set == "ad_text":
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                    'w_optm_lr': opt.w_optm_lr,
                    'sparse_weight': opt.attack_sparse_weight,
                    'out_type': "text"
                }
                embd = model(mode="text_to_embd",
                             input=text)  #in bs, len sent, vocab
                n, l, s = text_like_syn.shape
                text_like_syn_embd = model(mode="text_to_embd",
                                           input=text_like_syn.reshape(
                                               n, l * s)).reshape(n, l, s, -1)
                text_adv = model(mode="get_adv_by_convex_syn",
                                 input=embd,
                                 label=label,
                                 text_like_syn_embd=text_like_syn_embd,
                                 text_like_syn_valid=text_like_syn_valid,
                                 text_like_syn=text_like_syn,
                                 attack_type_dict=attack_type_dict)

            elif opt.pert_set == "ad_text_syn_p":
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                    'w_optm_lr': opt.w_optm_lr,
                    'sparse_weight': opt.train_attack_sparse_weight,
                    'out_type': "comb_p"
                }
                embd = model(mode="text_to_embd",
                             input=text)  #in bs, len sent, vocab
                n, l, s = text_like_syn.shape
                text_like_syn_embd = model(mode="text_to_embd",
                                           input=text_like_syn.reshape(
                                               n, l * s)).reshape(n, l, s, -1)
                adv_comb_p = model(mode="get_adv_by_convex_syn",
                                   input=embd,
                                   label=label,
                                   text_like_syn_embd=text_like_syn_embd,
                                   text_like_syn_valid=text_like_syn_valid,
                                   attack_type_dict=attack_type_dict)

            elif opt.pert_set == "ad_text_hotflip":
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                }
                text_adv = model(mode="get_adv_hotflip",
                                 input=text,
                                 label=label,
                                 text_like_syn_valid=text_like_syn_valid,
                                 text_like_syn=text_like_syn,
                                 attack_type_dict=attack_type_dict)

            elif opt.pert_set == "l2_ball":
                set_radius = opt.train_attack_eps
                attack_type_dict = {
                    'num_steps': opt.train_attack_iters,
                    'step_size': opt.train_attack_step_size * set_radius,
                    'random_start': opt.random_start,
                    'epsilon': set_radius,
                    #'loss_func': 'ce',
                    'loss_func': 'ce' if opt.if_ce_adp else 'kl',
                    'direction': 'away',
                    'ball_range': opt.l2_ball_range,
                }
                embd = model(mode="text_to_embd",
                             input=text)  #in bs, len sent, vocab
                embd_adv = model(mode="get_embd_adv",
                                 input=embd,
                                 label=label,
                                 attack_type_dict=attack_type_dict)

            optimizer.zero_grad()
            # clean loss
            predicted = model(mode="text_to_logit", input=text)
            loss_clean = loss_fun(predicted, label)
            # adv loss
            if opt.pert_set == "ad_text" or opt.pert_set == "ad_text_hotflip":
                predicted_adv = model(mode="text_to_logit", input=text_adv)
            elif opt.pert_set == "ad_text_syn_p":
                predicted_adv = model(mode="text_syn_p_to_logit",
                                      input=text_like_syn,
                                      comb_p=adv_comb_p)
            elif opt.pert_set == "l2_ball":
                predicted_adv = model(mode="embd_to_logit", input=embd_adv)

            loss_adv = loss_fun(predicted_adv, label)
            # kl loss
            criterion_kl = nn.KLDivLoss(reduction="sum")
            loss_kl = (1.0 / bs) * criterion_kl(
                F.log_softmax(predicted_adv, dim=1), F.softmax(predicted,
                                                               dim=1))

            # optimize
            loss = opt.weight_kl * kl_control * loss_kl + weight_adv * loss_adv + weight_clean * loss_clean
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            sum_loss_adv += loss_adv.item()
            sum_loss_clean += loss_clean.item()
            sum_loss_kl += loss_kl.item()
            predicted, idx = torch.max(predicted, 1)
            precision = (idx == label).float().mean().item()
            predicted_adv, idx = torch.max(predicted_adv, 1)
            precision_adv = (idx == label).float().mean().item()
            total += 1

            out_log = "%d epoch %d iters: loss: %.3f, loss_kl: %.3f, loss_adv: %.3f, loss_clean: %.3f | acc: %.3f acc_adv: %.3f | in %.3f seconds" % (
                epoch, iters, sum_loss / total, sum_loss_kl / total,
                sum_loss_adv / total, sum_loss_clean / total, precision,
                precision_adv, time.time() - start)
            start = time.time()
            print(out_log)

        scheduler.step()

        if epoch % 1 == 0:
            acc = utils.imdb_evaluation(opt, device, model, dev_iter)
            out_log = "%d epoch with dev acc %.4f" % (epoch, acc)
            print(out_log)
            adv_acc = utils.imdb_evaluation_ascc_attack(
                opt, device, model, dev_iter, tokenizer)
            out_log = "%d epoch with dev adv acc against ascc attack %.4f" % (
                epoch, adv_acc)
            print(out_log)

            #hotflip_adv_acc=utils.evaluation_hotflip_adv(opt, device, model, dev_iter, tokenizer)
            #out_log="%d epoch with dev hotflip adv acc %.4f" % (epoch,hotflip_adv_acc)
            #logger.info(out_log)
            #print(out_log)

            if adv_acc >= best_adv_acc:
                best_adv_acc = adv_acc
                best_save_dir = os.path.join(opt.out_path,
                                             "{}_best.pth".format(opt.model))
                state = {
                    'net': model.state_dict(),
                    'epoch': epoch,
                }
                torch.save(state, best_save_dir)

    # restore best according to dev set
    model = set_params(model, best_save_dir)
    acc = utils.imdb_evaluation(opt, device, model, test_iter)
    print("test acc %.4f" % (acc))
    adv_acc = utils.imdb_evaluation_ascc_attack(opt, device, model, test_iter,
                                                tokenizer)
    print("test adv acc against ascc attack %.4f" % (adv_acc))
    genetic_attack(opt,
                   device,
                   model,
                   attack_surface,
                   dataset=opt.dataset,
                   genetic_test_num=opt.genetic_test_num)
    fool_text_classifier_pytorch(opt,
                                 device,
                                 model,
                                 dataset=opt.dataset,
                                 clean_samples_cap=opt.pwws_test_num)
Example #7
def make_optimizer(model):
    params = []
    for key, value in model.named_parameters():

        if not value.requires_grad:
            continue

        lr = args.lr
        weight_decay = args.wd

        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
    optimizer = getattr(torch.optim, 'SGD')(params, momentum=0.9)
    return optimizer

optimizer = make_optimizer(net)
scheduler = WarmupMultiStepLR(optimizer, (args.lr_decay_1, args.lr_decay_2), 0.1, 1.0/3.0, args.warmup_epoch, 'linear')

def api_net_train(args, net, epoch, summary_writer):

    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss_adv = train_loss_adv_api = 0
    correct_adv = correct_adv_api = 0
    total = 0
    for batch_idx, (x_natural, y) in enumerate(trainloader):

        x_natural, y = x_natural.to(device), y.to(device)

        batch_size, ch, h, w = x_natural.shape
        num_classes = len(classes)
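The snippet above builds the optimizer and scheduler at module level and starts a per-batch training loop, but the scheduler step is not shown. Since the milestones (args.lr_decay_1, args.lr_decay_2) and args.warmup_epoch are expressed in epochs, a plausible driver would step it once per epoch; the sketch below is a guess, and args.epochs, net, and summary_writer are assumed to exist.

# Hypothetical driver loop (not part of the original snippet).
for epoch in range(args.epochs):
    api_net_train(args, net, epoch, summary_writer)  # one pass over trainloader
    scheduler.step()  # advance the warmup / multi-step schedule once per epoch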
Example #8
    model = SSD(cfg)
    if torch.cuda.is_available():
        model = model.cuda()
        # model = torch.nn.DataParallel(model).cuda()
    # else:
    # model = torch.nn.DataParallel(model)

    multiBoxLoss = MultiBoxLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=1e-3,
                                momentum=0.9,
                                weight_decay=5e-4)
    # milestones are expressed in iterations (epochs * batches per epoch),
    # which suggests this scheduler is stepped once per batch rather than per epoch
    scheduler = WarmupMultiStepLR(
        optimizer=optimizer,
        milestones=[120 * len(dataloader_train), 160 * len(dataloader_train)],
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_iters=500)

    metricLogger = MetricLogger()
    for epoch_num in range(200):
        loss_hist = []
        epoch_loss = []
        for iter_num, data in enumerate(dataloader_train):
            model.train()
            optimizer.zero_grad()
            gt_boxes = data['bbox']
            images = data['image']

            classification, regression, anchors = model(images)
            loss_dict = multiBoxLoss(classification, regression, anchors,