Example #1
    def init_self_ditsllation_models(self):

        # Probe the first model with a dummy ImageNet-sized input so that its
        # hooked feature maps are populated, then record the channel count and
        # spatial size of the last extracted map.
        input_size = (1, 3, 224, 224)

        noise_input = torch.randn(input_size).to(self.device)
        self.models[0](noise_input)
        trans_input = list(self.models[0].total_feature_maps.values())[-1]
        self.fusion_channel = trans_input.size(1)
        self.fusion_spatil = trans_input.size(2)

        # One self-distillation head, Adam optimizer, and scheduler for each
        # peer model after the first; the leader gets its own head below.
        self.sd_models = []
        self.sd_optimizers = []
        self.sd_schedulers = []
        for i in range(1, self.model_num):
            sd_model = SelfDistillationModel(
                input_channel=trans_input.size(1),
                layer_num=len(self.models[0].extract_layers) - 1).to(
                    self.device)
            sd_optimizer = optim.Adam(sd_model.parameters(),
                                      weight_decay=self.opt.weight_decay)
            sd_scheduler = utils.get_scheduler(sd_optimizer, self.opt)
            self.sd_models.append(sd_model)
            self.sd_optimizers.append(sd_optimizer)
            self.sd_schedulers.append(sd_scheduler)

        # Separate self-distillation head for the leader model.
        self.sd_leader_model = SelfDistillationModel(
            input_channel=trans_input.size(1),
            layer_num=len(self.models[0].extract_layers) - 1).to(self.device)
        self.sd_leader_optimizer = optim.Adam(
            self.sd_leader_model.parameters(),
            weight_decay=self.opt.weight_decay)
        self.sd_leader_scheduler = utils.get_scheduler(
            self.sd_leader_optimizer, self.opt)
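
The probe above relies on the wrapped model exposing total_feature_maps. Below is a minimal sketch of how such a mapping could be populated with forward hooks; HookedBackbone and its arguments are hypothetical illustrations, not the repository's actual wrapper:

import torch
import torch.nn as nn

class HookedBackbone(nn.Module):
    """Hypothetical wrapper that records intermediate feature maps."""

    def __init__(self, backbone, extract_layers):
        super().__init__()
        self.backbone = backbone
        self.extract_layers = extract_layers  # e.g. ['layer1', ..., 'layer4']
        self.total_feature_maps = {}
        for name, module in backbone.named_modules():
            if name in extract_layers:
                module.register_forward_hook(self._save_hook(name))

    def _save_hook(self, name):
        def hook(module, inputs, output):
            # Overwritten on every forward pass.
            self.total_feature_maps[name] = output
        return hook

    def forward(self, x):
        return self.backbone(x)

After one dummy forward pass, list(model.total_feature_maps.values())[-1] is the deepest recorded map, which is exactly what init_self_ditsllation_models reads.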
Example #2
    def __init__(self, opt, logger):

        self.opt = opt
        self.opt.isTrain = True
        self.logger = logger
        self.visualizer = Visualizer(opt)
        self.device = torch.device(
            f'cuda:{opt.gpu_ids[0]}') if torch.cuda.is_available() else 'cpu'

        self.epochs = opt.n_epochs
        self.start_epochs = opt.start_epoch
        self.train_batch_size = self.opt.train_batch_size
        self.temperature = self.opt.temperature

        dataLoader = create_dataLoader(opt)
        self.trainLoader = dataLoader.trainLoader
        self.testLoader = dataLoader.testLoader

        self.criterion_CE = nn.CrossEntropyLoss().to(self.device)
        self.criterion_KL = nn.KLDivLoss(reduction='batchmean').to(self.device)

        self.model_num = opt.model_num
        # One backbone, SGD optimizer, and LR scheduler per peer model.
        self.models = []
        self.optimizers = []
        self.schedulers = []
        for i in range(self.model_num):
            model = create_model(opt).to(self.device)
            optimizer = optim.SGD(model.parameters(),
                                  lr=opt.lr,
                                  momentum=opt.momentum,
                                  weight_decay=opt.weight_decay,
                                  nesterov=True)
            scheduler = utils.get_scheduler(optimizer, opt)
            self.models.append(model)
            self.optimizers.append(optimizer)
            self.schedulers.append(scheduler)

        self.init_self_ditsllation_models()
        self.init_fusion_module()

        # The leader model is created separately and receives the fusion shape
        # information discovered during init_self_ditsllation_models().
        self.leader_model = create_model(
            self.opt,
            leader=True,
            trans_fusion_info=(self.fusion_channel,
                               self.model_num)).to(self.device)
        self.leader_optimizer = optim.SGD(
            self.leader_model.parameters(),
            lr=opt.lr,
            momentum=opt.momentum,
            weight_decay=self.opt.leader_weight_decay,
            nesterov=True)
        self.leader_scheduler = utils.get_scheduler(self.leader_optimizer, opt)
Example #3
    def init_self_ditsllation_models(self):

        if self.opt.dataset.startswith('cifar'):
            input_size = (1, 3, 32, 32)
        else:
            input_size = (1, 3, 224, 224)

        noise_input = torch.randn(input_size).to(self.device)
        self.models[0](noise_input)
        trans_input = list(self.models[0].total_feature_maps.values())[-1]
        self.fusion_channel = trans_input.size(1)
        self.fusion_spatil = trans_input.size(2)

        self.sd_models = []
        self.sd_optimizers = []
        self.sd_schedulers = []
        for i in range(1, self.model_num):
            # DenseNet and GoogLeNet expose intermediate maps with different
            # channel counts, so their heads use hand-picked configurations.
            if self.opt.model.startswith('densenet'):
                sd_model = DIYSelfDistillationModel([456, 312, 168],
                                                    2).to(self.device)
            elif self.opt.model.startswith('googlenet'):
                sd_model = DIYSelfDistillationModel([1024, 832, 480],
                                                    2).to(self.device)
            else:
                sd_model = SelfDistillationModel(
                    input_channel=trans_input.size(1),
                    layer_num=len(self.models[0].extract_layers) - 1).to(
                        self.device)
            sd_optimizer = optim.Adam(sd_model.parameters(),
                                      weight_decay=self.opt.weight_decay)
            sd_scheduler = utils.get_scheduler(sd_optimizer, self.opt)
            self.sd_models.append(sd_model)
            self.sd_optimizers.append(sd_optimizer)
            self.sd_schedulers.append(sd_scheduler)
        if self.opt.model.startswith('densenet'):
            self.sd_leader_model = DIYSelfDistillationModel([456, 312, 168],
                                                            2).to(self.device)
        elif self.opt.model.startswith('googlenet'):
            self.sd_leader_model = DIYSelfDistillationModel([1024, 832, 480],
                                                            2).to(self.device)
        else:
            self.sd_leader_model = SelfDistillationModel(
                input_channel=trans_input.size(1),
                layer_num=len(self.models[0].extract_layers) - 1).to(
                    self.device)
        self.sd_leader_optimizer = optim.Adam(
            self.sd_leader_model.parameters(),
            weight_decay=self.opt.weight_decay)
        self.sd_leader_scheduler = utils.get_scheduler(
            self.sd_leader_optimizer, self.opt)
Example #4
    def __init__(self, opt):
        super(MaskMobilePix2PixModel, self).__init__()

        self.opt = opt
        self.device = torch.device(f'cuda:{opt.gpu_ids[0]}') if len(opt.gpu_ids) > 0 else 'cpu'
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake', 'mask_weight']
        self.visual_names = ['real_A', 'fake_B', 'real_B']

        self.netG = MaskMobileResnetGenerator(opt=self.opt, ngf=self.opt.ngf)
        # The conditional discriminator sees the input image concatenated with
        # the (real or generated) output image, hence 3 + 3 input channels.
        self.netD = NLayerDiscriminator(input_nc=3 + 3, ndf=128)
        self.init_net()

        # Names of the layers whose mask weights are handled as one group.
        self.group_mask_weight_names = []
        self.group_mask_weight_names.append('model.11')
        for i in range(13, 22):
            self.group_mask_weight_names.append('model.%d.conv_block.9' % i)

        self.stop_mask = False

        self.criterionGAN = GANLoss(self.opt.gan_mode).to(self.device)
        self.criterionL1 = nn.L1Loss()

        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(0.5, 0.999))
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)
        self.schedulers = [util.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
Example #5
    def __init__(self, opt):
        super(MaskMobileCycleGANModel, self).__init__()
        self.opt = opt
        self.device = torch.device(f'cuda:{opt.gpu_ids[0]}') if len(
            opt.gpu_ids) > 0 else 'cpu'
        self.loss_names = [
            'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B',
            'mask_weight'
        ]
        visual_names_A = ['real_A', 'fake_B', 'rec_A', 'idt_B']
        visual_names_B = ['real_B', 'fake_A', 'rec_B', 'idt_A']
        self.visual_names = visual_names_A + visual_names_B

        self.netG_A = MaskMobileResnetGenerator(opt=self.opt, ngf=self.opt.ngf)
        self.netG_B = MaskMobileResnetGenerator(opt=self.opt, ngf=self.opt.ngf)

        self.netD_A = NLayerDiscriminator(ndf=self.opt.ndf)
        self.netD_B = NLayerDiscriminator(ndf=self.opt.ndf)
        self.init_net()

        # History buffers of previously generated images used for the
        # discriminator updates (see the ImagePool sketch after this example).
        self.fake_A_pool = ImagePool(50)
        self.fake_B_pool = ImagePool(50)

        self.group_mask_weight_names = []
        self.group_mask_weight_names.append('model.11')
        for i in range(13, 22):
            self.group_mask_weight_names.append('model.%d.conv_block.9' % i)

        self.stop_AtoB_mask = False
        self.stop_BtoA_mask = False

        # define loss functions
        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        self.criterionCycle = nn.L1Loss()
        self.criterionIdt = nn.L1Loss()

        # define optimizers
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.lr,
                                            betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.netD_A.parameters(), self.netD_B.parameters()),
                                            lr=opt.lr,
                                            betas=(0.5, 0.999))
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)
        self.schedulers = [
            util.get_scheduler(optimizer, opt) for optimizer in self.optimizers
        ]
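
Both CycleGAN examples draw discriminator inputs from an ImagePool. A condensed sketch of that standard history buffer, assuming the usual CycleGAN semantics rather than this repository's exact implementation:

import random
import torch

class ImagePool:
    """Keeps up to pool_size previously generated images and sometimes swaps
    them in for the current fakes, decorrelating discriminator updates from
    the latest generator state."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                # Fill the buffer first, returning images unchanged.
                self.images.append(image)
                out.append(image)
            elif random.random() > 0.5:
                # Return a stored image and replace it with the new one.
                idx = random.randrange(len(self.images))
                out.append(self.images[idx].clone())
                self.images[idx] = image
            else:
                out.append(image)
        return torch.cat(out, dim=0)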
Example #6
    def init_fusion_module(self):

        # The class count is hard-coded rather than read from opt.
        self.num_classes = 100

        self.fusion_module = FusionModule(self.fusion_channel,
                                          self.num_classes,
                                          self.fusion_spatil,
                                          model_num=self.model_num).to(
                                              self.device)
        self.fusion_optimizer = optim.SGD(self.fusion_module.parameters(),
                                          lr=self.opt.lr,
                                          momentum=self.opt.momentum,
                                          weight_decay=1e-5,
                                          nesterov=True)
        self.fusion_scheduler = utils.get_scheduler(self.fusion_optimizer,
                                                    self.opt)
Example #7
    def __init__(self, opt, filter_cfgs=None, channel_cfgs=None):
        super(Pix2PixModel, self).__init__()

        self.opt = opt
        self.device = torch.device(f'cuda:{opt.gpu_ids[0]}') if len(
            opt.gpu_ids) > 0 else 'cpu'
        self.filter_cfgs = filter_cfgs
        self.channel_cfgs = channel_cfgs
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        self.visual_names = ['real_A', 'fake_B', 'real_B']

        self.netG = UnetGenertor(input_nc=3,
                                 output_nc=3,
                                 num_downs=8,
                                 ngf=opt.ngf,
                                 use_dropout=not opt.no_dropout,
                                 filter_cfgs=filter_cfgs,
                                 channel_cfgs=channel_cfgs)
        self.netD = NLayerDiscriminator(input_nc=3 + 3, ndf=128)
        self.init_net()

        # Optionally distill attention maps and/or discriminator outputs from
        # a pretrained teacher.
        self.teacher_model = None
        if self.opt.lambda_attention_distill > 0:
            print('init attention distill')
            self.init_attention_distill()
        if self.opt.lambda_discriminator_distill > 0:
            print('init discriminator distill')
            self.init_discriminator_distill()

        self.criterionGAN = GANLoss(self.opt.gan_mode).to(self.device)
        self.criterionL1 = nn.L1Loss()

        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=opt.lr,
                                            betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr,
                                            betas=(0.5, 0.999))
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)
        self.schedulers = [
            util.get_scheduler(optimizer, opt) for optimizer in self.optimizers
        ]
Example #8
    def __init__(self, opt, cfg_AtoB=None, cfg_BtoA=None):
        super(MobileCycleGANModel, self).__init__()
        self.opt = opt
        self.device = torch.device(f'cuda:{opt.gpu_ids[0]}') if len(opt.gpu_ids) > 0 else 'cpu'
        self.cfg_AtoB = cfg_AtoB
        self.cfg_BtoA = cfg_BtoA
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        visual_names_A = ['real_A', 'fake_B', 'rec_A', 'idt_B']
        visual_names_B = ['real_B', 'fake_A', 'rec_B', 'idt_A']
        self.visual_names = visual_names_A + visual_names_B

        self.netG_A = MobileResnetGenerator(opt=self.opt, cfg=cfg_AtoB)
        self.netG_B = MobileResnetGenerator(opt=self.opt, cfg=cfg_BtoA)

        self.netD_A = NLayerDiscriminator()
        self.netD_B = NLayerDiscriminator()
        self.init_net()

        self.fake_A_pool = ImagePool(50)
        self.fake_B_pool = ImagePool(50)

        self.teacher_model = None
        if self.opt.lambda_attention_distill > 0:
            print('init attention distill')
            self.init_attention_distill()
        if self.opt.lambda_discriminator_distill > 0:
            print('init discriminator distill')
            self.init_discriminator_distill()

        # define loss functions
        self.criterionGAN = GANLoss(opt.gan_mode).to(self.device)
        self.criterionCycle = nn.L1Loss()
        self.criterionIdt = nn.L1Loss()

        # define optimizers
        self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                            lr=opt.lr, betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                            lr=opt.lr, betas=(0.5, 0.999))
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)
        self.schedulers = [util.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
Example #9
    def __init__(self, opt):
        super(MaskPix2PixModel, self).__init__()

        self.opt = opt
        self.device = torch.device(f'cuda:{opt.gpu_ids[0]}') if len(opt.gpu_ids) > 0 else 'cpu'
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake', 'mask_weight']
        self.visual_names = ['real_A', 'fake_B', 'real_B']

        self.netG = MaskUnetGenertor(input_nc=3, output_nc=3, num_downs=8, ngf=opt.ngf, use_dropout=not opt.no_dropout, opt=self.opt)
        self.netD = NLayerDiscriminator(input_nc=3 + 3, ndf=128)
        self.init_net()

        self.stop_mask = False

        self.criterionGAN = GANLoss(self.opt.gan_mode).to(self.device)
        self.criterionL1 = nn.L1Loss()

        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(0.5, 0.999))
        self.optimizers = []
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)
        self.schedulers = [util.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
Example #10
def train(conf, data_category):
    print(json.dumps(conf, indent=4))

    os.environ["CUDA_VISIBLE_DEVICES"] = str(conf['device'])
    device = torch.device(0)

    model_name = conf['model']['name']
    optimizer_name = conf['optimizer']['name']
    data_set = conf['data']['dataset']
    graph = h5py.File(os.path.join('data', data_set, 'all_graph.h5'), 'r')
    scheduler_name = conf['scheduler']['name']
    loss = get_loss(**conf['loss'])
    # data_category = conf['data']['data_category']

    loss.to(device)
    encoder, decoder, support = None, None, None
    if model_name == 'Costnet':
        base_model_name = conf['Base']['name']
        encoder, decoder = preprocessing(base_model_name, conf, loss, graph,
                                         data_category, device, data_set,
                                         optimizer_name, scheduler_name)
    if model_name in ('Metricnet', 'GWNET', 'Evonet', 'STGCN', 'DCRNN',
                      'STG2Seq', 'Evonet2'):
        support = preprocessing_for_metric(
            data_category=data_category,
            dataset=conf['data']['dataset'],
            Normal_Method=conf['data']['Normal_Method'],
            _len=conf['data']['_len'],
            **conf['preprocess'])
    model, trainer = create_model(model_name, loss, conf['model'][model_name],
                                  data_category, device, graph, encoder,
                                  decoder, support)

    optimizer = get_optimizer(optimizer_name, model.parameters(),
                              conf['optimizer'][optimizer_name]['lr'])
    scheduler = get_scheduler(scheduler_name, optimizer,
                              **conf['scheduler'][scheduler_name])
    # Move the model to the primary device before (optionally) wrapping it in
    # DataParallel; the original only moved it in the single-GPU branch.
    model.to(device)
    if torch.cuda.device_count() > 1:
        print("using", torch.cuda.device_count(), "GPUs")
        model = nn.DataParallel(model)

    save_folder = os.path.join('save', conf['name'],
                               f'{data_set}_{"".join(data_category)}',
                               conf['tag'])
    run_folder = os.path.join('run', conf['name'],
                              f'{data_set}_{"".join(data_category)}',
                              conf['tag'])

    # Start each run with clean save and tensorboard folders.
    shutil.rmtree(save_folder, ignore_errors=True)
    os.makedirs(save_folder)
    shutil.rmtree(run_folder, ignore_errors=True)
    os.makedirs(run_folder)

    with open(os.path.join(save_folder, 'config.yaml'), 'w+') as _f:
        yaml.safe_dump(conf, _f)

    data_loader, normal = get_data_loader(**conf['data'],
                                          data_category=data_category,
                                          device=device,
                                          model_name=model_name)

    # With both data categories, train the joint model; otherwise run the
    # single-category baseline and evaluate it.
    if len(data_category) == 2:
        train_model(model=model,
                    dataloaders=data_loader,
                    trainer=trainer,
                    node_num=conf['node_num'],
                    loss_func=loss,
                    optimizer=optimizer,
                    normal=normal,
                    scheduler=scheduler,
                    folder=save_folder,
                    tensorboard_folder=run_folder,
                    device=device,
                    **conf['train'])
        # test_model(folder = save_folder)
    else:
        train_baseline(model=model,
                       dataloaders=data_loader,
                       trainer=trainer,
                       optimizer=optimizer,
                       normal=normal,
                       scheduler=scheduler,
                       folder=save_folder,
                       tensorboard_folder=run_folder,
                       device=device,
                       **conf['train'])
        test_baseline(folder=save_folder,
                      trainer=trainer,
                      model=model,
                      normal=normal,
                      dataloaders=data_loader,
                      device=device)
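
The function reads a nested configuration dictionary. A minimal sketch covering only the keys train() touches; every concrete value, and the model/optimizer/scheduler names, are illustrative guesses rather than values taken from the repository:

conf = {
    'name': 'experiment',
    'tag': 'run1',
    'device': 0,
    'node_num': 100,   # only read when len(data_category) == 2
    'model': {'name': 'STGCN', 'STGCN': {}},
    'optimizer': {'name': 'Adam', 'Adam': {'lr': 1e-3}},
    'scheduler': {'name': 'StepLR', 'StepLR': {'step_size': 10}},
    'loss': {},        # splatted into get_loss(**conf['loss'])
    'data': {'dataset': 'nyc', 'Normal_Method': 'Standard', '_len': 12},
    'preprocess': {},
    'train': {'epochs': 100},
}
train(conf, data_category=['taxi', 'bike'])

A 'Base' section would additionally be required for the Costnet preprocessing path.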
Example #11
    def train(self,
              mention_dataset,
              candidate_dataset,
              inbatch=True,
              lr=1e-5,
              batch_size=32,
              random_bsz=100000,
              max_ctxt_len=32,
              max_title_len=50,
              max_desc_len=100,
              traindata_size=1000000,
              model_save_interval=10000,
              grad_acc_step=1,
              max_grad_norm=1.0,
              epochs=1,
              warmup_propotion=0.1,
              fp16=False,
              fp16_opt_level=None,
              parallel=False,
              hard_negative=False,
             ):


        if inbatch:

            optimizer = optim.Adam(self.model.parameters(), lr=lr)
            scheduler = get_scheduler(
                batch_size, grad_acc_step, epochs, warmup_propotion, optimizer, traindata_size)

            if fp16:
                assert fp16_opt_level is not None
                self.model, optimizer = to_fp16(self.model, optimizer, fp16_opt_level)

            if parallel:
                self.model = to_parallel(self.model)

            for e in range(epochs):
                #mention_batch = mention_dataset.batch(batch_size=batch_size, random_bsz=random_bsz, max_ctxt_len=max_ctxt_len)
                dataloader = DataLoader(mention_dataset, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn_json, num_workers=2)
                bar = tqdm(total=traindata_size)
                #for step, (input_ids, labels) in enumerate(mention_batch):
                for step, (input_ids, labels, lines) in enumerate(dataloader):
                    if self.logger:
                        self.logger.debug("%s step", step)
                        self.logger.debug("%s data in batch", len(input_ids))
                        self.logger.debug("%s unique labels in %s labels", len(set(labels)), len(labels))

                    # Pad mention token ids to a rectangular batch; the mask
                    # is True at non-pad positions.
                    inputs = pad_sequence([torch.LongTensor(token)
                                          for token in input_ids], padding_value=0).t().to(self.device)
                    input_mask = inputs > 0

                    mention_reps = self.model(inputs, input_mask, is_mention=True)

                    pages = list(labels)
                    if hard_negative:
                        # Mine hard negatives: walk each mention's retrieved
                        # neighbors and stop once the gold label is reached.
                        for label, line in zip(labels, lines):
                            for i in line["nearest_neighbors"]:
                                if str(i) == label:
                                    break
                                pages.append(str(i))

                    candidate_input_ids = candidate_dataset.get_pages(pages, max_title_len=max_title_len, max_desc_len=max_desc_len)
                    candidate_inputs = pad_sequence([torch.LongTensor(token)
                                                    for token in candidate_input_ids], padding_value=0).t().to(self.device)
                    candidate_mask = candidate_inputs > 0
                    candidate_reps = self.model(candidate_inputs, candidate_mask, is_mention=False)

                    # Score every mention against every candidate; with
                    # in-batch negatives the gold pairs lie on the diagonal.
                    scores = mention_reps.mm(candidate_reps.t())
                    accuracy = self.calculate_inbatch_accuracy(scores)

                    target = torch.arange(scores.size(0), device=self.device)
                    loss = F.cross_entropy(scores, target, reduction="mean")

                    if self.logger:
                        self.logger.debug("Accurac: %s", accuracy)
                        self.logger.debug("Train loss: %s", loss.item())


                    if fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()


                    if (step + 1) % grad_acc_step == 0:
                        if fp16:
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(optimizer), max_grad_norm
                            )
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                self.model.parameters(), max_grad_norm
                            )
                        optimizer.step()
                        scheduler.step()
                        optimizer.zero_grad()

                        if self.logger:
                            self.logger.debug("Back propagation in step %s", step+1)
                            self.logger.debug("LR: %s", scheduler.get_lr())

                    if self.use_mlflow:
                        mlflow.log_metric("train loss", loss.item(), step=step)
                        mlflow.log_metric("accuracy", accuracy, step=step)

                    if self.model_path is not None and step % model_save_interval == 0:
                        #torch.save(self.model.state_dict(), self.model_path)
                        save_model(self.model, self.model_path)

                    bar.update(len(input_ids))
                    bar.set_description(f"Loss: {loss.item()}, Accuracy: {accuracy}")
Example #12
    def train(
        self,
        mention_dataset,
        candidate_dataset,
        lr=1e-5,
        max_ctxt_len=32,
        max_title_len=50,
        max_desc_len=100,
        traindata_size=1000000,
        model_save_interval=10000,
        grad_acc_step=1,
        max_grad_norm=1.0,
        epochs=1,
        warmup_propotion=0.1,
        fp16=False,
        fp16_opt_level=None,
        parallel=False,
    ):

        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        scheduler = get_scheduler(1, grad_acc_step, epochs, warmup_propotion,
                                  optimizer, traindata_size)

        if fp16:
            assert fp16_opt_level is not None
            self.model, optimizer = to_fp16(self.model, optimizer,
                                            fp16_opt_level)

        if parallel:
            self.model = to_parallel(self.model)

        for e in range(epochs):
            dataloader = DataLoader(mention_dataset,
                                    batch_size=1,
                                    shuffle=True,
                                    collate_fn=my_collate_fn_json,
                                    num_workers=2)
            bar = tqdm(total=traindata_size)
            for step, (input_ids, labels, lines) in enumerate(dataloader):
                if step > traindata_size:
                    break

                if self.logger:
                    self.logger.debug("%s step", step)
                    self.logger.debug("%s data in batch", len(input_ids))
                    self.logger.debug("%s unique labels in %s labels",
                                      len(set(labels)), len(labels))

                pages = list(labels)
                # Append retrieved neighbors as negative candidates, skipping
                # ids already present as gold labels.
                for neighbor in lines[0]["nearest_neighbors"]:
                    if str(neighbor) not in pages:
                        pages.append(str(neighbor))
                candidate_input_ids = candidate_dataset.get_pages(
                    pages,
                    max_title_len=max_title_len,
                    max_desc_len=max_desc_len)

                inputs = self.merge_mention_candidate(input_ids[0],
                                                      candidate_input_ids)

                inputs = pad_sequence(
                    [torch.LongTensor(token) for token in inputs],
                    padding_value=0).t().to(self.device)
                input_mask = inputs > 0
                scores = self.model(inputs, input_mask)

                # The gold candidate is merged in at index 0, so the target
                # class is always 0.
                target = torch.LongTensor([0]).to(self.device)
                loss = F.cross_entropy(scores.unsqueeze(0),
                                       target,
                                       reduction="mean")

                if self.logger:
                    self.logger.debug("Train loss: %s", loss.item())

                if fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if (step + 1) % grad_acc_step == 0:
                    if fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                       max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

                    if self.logger:
                        self.logger.debug("Back propagation in step %s",
                                          step + 1)
                        self.logger.debug("LR: %s", scheduler.get_lr())

                if self.use_mlflow:
                    mlflow.log_metric("train loss", loss.item(), step=step)

                if self.model_path is not None and step % model_save_interval == 0:
                    #torch.save(self.model.state_dict(), self.model_path)
                    save_model(self.model, self.model_path)

                bar.update(len(input_ids))
                bar.set_description(f"Loss: {loss.item()}")
Example #13
def preprocessing(base_model_name, conf, loss, graph, data_category, device,
                  data_set, optimizer_name, scheduler_name):

    # LinearDecompose is trained with the usual loop, then its encoder and
    # decoder are handed back for the main model.
    if base_model_name == 'LinearDecompose':
        data_loader = get_data_loader_base(base_model_name=base_model_name,
                                           dataset=conf['data']['dataset'],
                                           batch_size=conf['batch_size_base'],
                                           _len=conf['data']['_len'],
                                           data_category=data_category,
                                           device=device)
        model, trainer = create_model(base_model_name, loss,
                                      conf['Base'][base_model_name],
                                      data_category, device, graph)
        save_folder = os.path.join('saves',
                                   f"{conf['name']}_{base_model_name}",
                                   f'{data_set}_{"".join(data_category)}')
        run_folder = os.path.join('run', f"{conf['name']}_{base_model_name}",
                                  f'{data_set}_{"".join(data_category)}')
        optimizer = get_optimizer(optimizer_name, model.parameters(),
                                  conf['optimizerbase'][optimizer_name]['lr'])
        scheduler = get_scheduler(scheduler_name, optimizer,
                                  **conf['scheduler'][scheduler_name])
        shutil.rmtree(save_folder, ignore_errors=True)
        os.makedirs(save_folder)
        shutil.rmtree(run_folder, ignore_errors=True)
        os.makedirs(run_folder)
        model = train_decompose(model=model,
                                dataloaders=data_loader,
                                trainer=trainer,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                folder=save_folder,
                                tensorboard_floder=run_folder,
                                device=device,
                                **conf['train'])
        model.load_state_dict(
            torch.load(os.path.join(save_folder, 'best_model.pkl'))
            ['model_state_dict'])
        return model.encoder, model.decoder
    # SvdDecompose needs no gradient-based training; it decomposes the raw
    # data matrix directly.
    if base_model_name == 'SvdDecompose':
        data = get_data_loader_base(base_model_name=base_model_name,
                                    dataset=conf['data']['dataset'],
                                    batch_size=conf['batch_size_base'],
                                    _len=conf['data']['_len'],
                                    data_category=data_category,
                                    device=device)
        data = torch.from_numpy(data).float().to(device)
        save_folder = os.path.join('saves',
                                   f"{conf['name']}_{base_model_name}",
                                   f'{data_set}_{"".join(data_category)}')
        run_folder = os.path.join('run', f"{conf['name']}_{base_model_name}",
                                  f'{data_set}_{"".join(data_category)}')
        model, trainer = create_model(base_model_name, loss,
                                      conf['Base'][base_model_name],
                                      data_category, device, graph)
        shutil.rmtree(save_folder, ignore_errors=True)
        os.makedirs(save_folder)
        shutil.rmtree(run_folder, ignore_errors=True)
        os.makedirs(run_folder)
        model.decompose(data)
        return model.encoder, model.decoder
Example #14
def train(conf, data_category):
    print(json.dumps(conf, indent=4))

    os.environ["CUDA_VISIBLE_DEVICES"] = str(conf['device'])
    device = torch.device(0)

    model_name = conf['model']['name']
    optimizer_name = conf['optimizer']['name']
    data_set = conf['data']['dataset']
    scheduler_name = conf['scheduler']['name']
    loss = get_loss(**conf['loss'])

    loss.to(device)


    support = preprocessing_for_metric(data_category=data_category,
                                       dataset=conf['data']['dataset'],
                                       Normal_Method=conf['data']['Normal_Method'],
                                       _len=conf['data']['_len'],
                                       **conf['preprocess'])
    model, trainer = create_model(model_name,
                                  loss,
                                  conf['model'][model_name],
                                  data_category,
                                  device,
                                  support)

    optimizer = get_optimizer(optimizer_name, model.parameters(), conf['optimizer'][optimizer_name]['lr'])
    scheduler = get_scheduler(scheduler_name, optimizer, **conf['scheduler'][scheduler_name])
    # Move the model to the primary device before (optionally) wrapping it in
    # DataParallel; the original only moved it in the single-GPU branch.
    model.to(device)
    if torch.cuda.device_count() > 1:
        print("using", torch.cuda.device_count(), "GPUs")
        model = nn.DataParallel(model)

    save_folder = os.path.join('save', conf['name'], f'{data_set}_{"".join(data_category)}', conf['tag'])
    run_folder = os.path.join('run', conf['name'], f'{data_set}_{"".join(data_category)}', conf['tag'])

    shutil.rmtree(save_folder, ignore_errors=True)
    os.makedirs(save_folder)
    shutil.rmtree(run_folder, ignore_errors=True)
    os.makedirs(run_folder)

    with open(os.path.join(save_folder, 'config.yaml'), 'w+') as _f:
        yaml.safe_dump(conf, _f)

    data_loader, normal = get_data_loader(**conf['data'], data_category=data_category, device=device,
                                          model_name=model_name)


    train_model(model=model,
                dataloaders=data_loader,
                trainer=trainer,
                optimizer=optimizer,
                normal=normal,
                scheduler=scheduler,
                folder=save_folder,
                tensorboard_folder=run_folder,
                device=device,
                **conf['train'])
    test_model(folder=save_folder,
               trainer=trainer,
               model=model,
               normal=normal,
               dataloaders=data_loader,
               device=device)