Example #1
class ExponentialLRScheduler(Callback):
    def __init__(self, gamma, epoch_every=1, batch_every=None):
        super().__init__()
        self.gamma = gamma
        if epoch_every == 0:
            self.epoch_every = False
        else:
            self.epoch_every = epoch_every
        if batch_every == 0:
            self.batch_every = False
        else:
            self.batch_every = batch_every

    def set_params(self, transformer, validation_datagen, *args, **kwargs):
        self.validation_datagen = validation_datagen
        self.model = transformer.model
        self.optimizer = transformer.optimizer
        self.loss_function = transformer.loss_function
        self.lr_scheduler = ExponentialLR(self.optimizer, self.gamma, last_epoch=-1)

    def on_train_begin(self, *args, **kwargs):
        self.epoch_id = 0
        self.batch_id = 0
        logger.info('initial lr: {0}'.format(self.optimizer.state_dict()['param_groups'][0]['initial_lr']))

    def on_epoch_end(self, *args, **kwargs):
        if self.epoch_every and (((self.epoch_id + 1) % self.epoch_every) == 0):
            self.lr_scheduler.step()
            logger.info('epoch {0} current lr: {1}'.format(self.epoch_id + 1,
                                                           self.optimizer.state_dict()['param_groups'][0]['lr']))
        self.epoch_id += 1

    def on_batch_end(self, *args, **kwargs):
        if self.batch_every and ((self.batch_id % self.batch_every) == 0):
            self.lr_scheduler.step()
            logger.info('epoch {0} batch {1} current lr: {2}'.format(
                self.epoch_id + 1, self.batch_id + 1, self.optimizer.state_dict()['param_groups'][0]['lr']))
        self.batch_id += 1
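
For reference, a minimal standalone sketch of what ExponentialLR does (not part of the snippet above; the toy model and optimizer are assumptions): every call to scheduler.step() multiplies each param group's lr by gamma, so after n steps lr == initial_lr * gamma ** n.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

model = torch.nn.Linear(4, 2)            # hypothetical model
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(3):
    # ... one epoch of training would go here ...
    optimizer.step()                     # step the optimizer before the scheduler
    scheduler.step()                     # lr is now 0.1 * 0.9 ** (epoch + 1)
    print(optimizer.param_groups[0]['lr'])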
Example #2
    equilibrium = 0.68

    # mse_lambda = 1.0
    # OPTIM-LOSS
    # an optimizer for each of the sub-networks, so we can selectively backprop
    # optimizer_encoder = Adam(params=net.encoder.parameters(),lr = lr,betas=(0.9,0.999))

    optimizer_encoder = RMSprop(params=net.encoder.parameters(),
                                lr=lr,
                                alpha=0.9,
                                eps=1e-8,
                                weight_decay=0,
                                momentum=0,
                                centered=False)
    # lr_encoder = MultiStepLR(optimizer_encoder,milestones=[2],gamma=1)
    lr_encoder = ExponentialLR(optimizer_encoder, gamma=decay_lr)
    # optimizer_decoder = Adam(params=net.decoder.parameters(),lr = lr,betas=(0.9,0.999))
    optimizer_decoder = RMSprop(params=net.decoder.parameters(),
                                lr=lr,
                                alpha=0.9,
                                eps=1e-8,
                                weight_decay=0,
                                momentum=0,
                                centered=False)
    lr_decoder = ExponentialLR(optimizer_decoder, gamma=decay_lr)
    # lr_decoder = MultiStepLR(optimizer_decoder,milestones=[2],gamma=1)
    # optimizer_discriminator = Adam(params=net.discriminator.parameters(),lr = lr,betas=(0.9,0.999))
    optimizer_discriminator = RMSprop(params=net.discriminator.parameters(),
                                      lr=lr,
                                      alpha=0.9,
                                      eps=1e-8,
                                      weight_decay=0,
                                      momentum=0,
                                      centered=False)
    lr_discriminator = ExponentialLR(optimizer_discriminator, gamma=decay_lr)
Example #3
    def __init__(self, cfg, num_classes, feat_dim):

        self.loss_type = cfg.LOSS.LOSS_TYPE
        self.loss_function_map = OrderedDict()

        # loss_function **kw should have:
        #     feat_t,
        #     feat_c,
        #     cls_score,
        #     cls_label,   # label
        #     source_feat_t,
        #     source_feat_c,

        # ID loss
        self.xent = None
        if 'softmax' in self.loss_type:
            self.xent = MyCrossEntropy(
                num_classes=num_classes,
                label_smooth=cfg.LOSS.IF_LABEL_SMOOTH,
                learning_weight=cfg.LOSS.IF_LEARNING_WEIGHT)

            if cfg.MODEL.DEVICE == 'cuda':
                self.xent = self.xent.cuda()

            if self.xent.learning_weight:
                self.xent.optimizer = torch.optim.SGD(self.xent.parameters(),
                                                      lr=0.0001,
                                                      momentum=0.9,
                                                      weight_decay=10**-4,
                                                      nesterov=True)
                self.xent.scheduler = ExponentialLR(self.xent.optimizer,
                                                    gamma=0.95,
                                                    last_epoch=-1)

            def loss_function(**kw):
                return cfg.LOSS.ID_LOSS_WEIGHT * self.xent(
                    kw['cls_score'], kw['cls_label'])

            self.loss_function_map["softmax"] = loss_function

        if 'arcface' in self.loss_type:
            self.arcface = ArcfaceLoss(num_classes=num_classes,
                                       feat_dim=feat_dim)

            if cfg.MODEL.DEVICE == 'cuda':
                self.arcface = self.arcface.cuda()

            def loss_function(**kw):
                return self.arcface(kw['feat_c'], kw['cls_label'])

            self.loss_function_map["arcface"] = loss_function

        # metric loss
        self.triplet = None
        if 'triplet' in self.loss_type:
            self.triplet = TripletLoss(cfg.LOSS.MARGIN, learning_weight=False)

            if cfg.MODEL.DEVICE == 'cuda':
                self.triplet = self.triplet.cuda()

            if self.triplet.learning_weight:
                self.triplet.optimizer = torch.optim.SGD(
                    self.triplet.parameters(),
                    lr=0.0001,
                    momentum=0.9,
                    weight_decay=10**-4,
                    nesterov=True)
                self.triplet.scheduler = ExponentialLR(self.triplet.optimizer,
                                                       gamma=0.95,
                                                       last_epoch=-1)

            def loss_function(**kw):
                return cfg.LOSS.METRIC_LOSS_WEIGHT * self.triplet(
                    kw['feat_t'], kw['cls_label'])

            self.loss_function_map["triplet"] = loss_function

        # cluster loss
        self.center = None
        if cfg.LOSS.IF_WITH_CENTER:
            self.center = CenterLoss(num_classes=num_classes,
                                     feat_dim=feat_dim,
                                     loss_weight=cfg.LOSS.CENTER_LOSS_WEIGHT,
                                     learning_weight=False)

            if cfg.MODEL.DEVICE == 'cuda':
                self.center = self.center.cuda()
            self.center.optimizer = torch.optim.SGD(self.center.parameters(),
                                                    lr=cfg.OPTIMIZER.LOSS_LR,
                                                    momentum=0.9,
                                                    weight_decay=10**-4,
                                                    nesterov=True)

            self.center.scheduler = ExponentialLR(self.center.optimizer,
                                                  gamma=0.995,
                                                  last_epoch=-1)

            def loss_function(**kw):
                return cfg.LOSS.CENTER_LOSS_WEIGHT * self.center(
                    kw['feat_t'], kw['cls_label'])

            self.loss_function_map["center"] = loss_function

            if cfg.LOSS.IF_WITH_DEC:
                self.dec = DECLoss()

                def loss_function(**kw):
                    return self.dec(kw['feat_t'], self.center.centers)

                self.loss_function_map["dec"] = loss_function

        # dist loss
        self.cross_entropy_dist_loss = None
        if cfg.CONTINUATION.IF_ON and "ce_dist" in cfg.CONTINUATION.LOSS_TYPE:
            self.cross_entropy_dist_loss = CrossEntropyDistLoss(
                T=cfg.CONTINUATION.T)

            def loss_function(**kw):
                return self.cross_entropy_dist_loss(kw['feat_c'],
                                                    kw['source_feat_c'])

            self.loss_function_map["ce_dist"] = loss_function

        self.triplet_dist_loss = None
        if cfg.CONTINUATION.IF_ON and "tr_dist" in cfg.CONTINUATION.LOSS_TYPE:
            self.triplet_dist_loss = TripletDistLoss(T=cfg.CONTINUATION.T)

            def loss_function(**kw):
                return self.triplet_dist_loss(kw['feat_t'],
                                              kw['source_feat_t'],
                                              kw['cls_label'])

            self.loss_function_map["tr_dist"] = loss_function

        print(self.loss_function_map.keys())
Example #4
    print(f'{len(ds)} images found for training')


def save_model(path):
    if not deepspeed_utils.is_root_worker():
        return

    save_obj = {'hparams': vae_params, 'weights': vae.state_dict()}

    torch.save(save_obj, path)


# optimizer

opt = Adam(vae.parameters(), lr=LEARNING_RATE)
sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

if deepspeed_utils.is_root_worker():
    # weights & biases experiment tracking

    import wandb

    model_config = dict(num_tokens=NUM_TOKENS,
                        smooth_l1_loss=SMOOTH_L1_LOSS,
                        num_resnet_blocks=NUM_RESNET_BLOCKS,
                        kl_loss_weight=KL_LOSS_WEIGHT)

    run = wandb.init(project='dalle_train_vae',
                     job_type='train_model',
                     config=model_config)
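
The snippet above only builds the optimizer/scheduler pair and sets up tracking; a hypothetical continuation (the epoch count, dataloader and loss call are assumptions, not part of the source) would step the scheduler once per epoch, so the lr after epoch e is LEARNING_RATE * LR_DECAY_RATE ** e:

for epoch in range(EPOCHS):              # EPOCHS: assumed constant
    for images in dl:                    # dl: assumed DataLoader over ds
        loss = vae(images, return_loss=True)   # assumes the VAE returns a scalar loss
        loss.backward()
        opt.step()
        opt.zero_grad()
    sched.step()                         # exponential decay applied per epoch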
Example #5
def train(ds,
          val_ds,
          fold,
          train_idx,
          val_idx,
          config,
          num_workers=0,
          transforms=None,
          val_transforms=None,
          num_channels_changed=False,
          final_changed=False,
          cycle=False):
    os.makedirs(os.path.join('..', 'weights'), exist_ok=True)
    os.makedirs(os.path.join('..', 'logs'), exist_ok=True)

    save_path = os.path.join('..', 'weights', config.folder)
    model = models[config.network](num_classes=config.num_classes,
                                   num_channels=config.num_channels)
    estimator = Estimator(model,
                          optimizers[config.optimizer],
                          save_path,
                          config=config,
                          num_channels_changed=num_channels_changed,
                          final_changed=final_changed)

    estimator.lr_scheduler = ExponentialLR(
        estimator.optimizer, config.lr_gamma
    )  #LRStepScheduler(estimator.optimizer, config.lr_steps)
    callbacks = [
        ModelSaver(1, ("fold" + str(fold) + "_best.pth"), best_only=True),
        ModelSaver(1, ("fold" + str(fold) + "_last.pth"), best_only=False),
        CheckpointSaver(1, ("fold" + str(fold) + "_checkpoint.pth")),
        # LRDropCheckpointSaver(("fold"+str(fold)+"_checkpoint_e{epoch}.pth")),
        ModelFreezer(),
        # EarlyStopper(10),
        TensorBoard(
            os.path.join('..', 'logs', config.folder, 'fold{}'.format(fold)))
    ]
    # if not num_channels_changed:
    #     callbacks.append(LastCheckpointSaver("fold"+str(fold)+"_checkpoint_rgb.pth", config.nb_epoch))

    hard_neg_miner = None  #HardNegativeMiner(rate=10)
    # metrics = [('dr', dice_round)]

    trainer = PytorchTrain(estimator,
                           fold=fold,
                           callbacks=callbacks,
                           hard_negative_miner=hard_neg_miner)

    train_loader = PytorchDataLoader(TrainDataset(ds,
                                                  train_idx,
                                                  config,
                                                  transforms=transforms),
                                     batch_size=config.batch_size,
                                     shuffle=True,
                                     drop_last=True,
                                     num_workers=num_workers,
                                     pin_memory=True)
    val_loader = PytorchDataLoader(ValDataset(val_ds,
                                              val_idx,
                                              config,
                                              transforms=val_transforms),
                                   batch_size=1,
                                   shuffle=False,
                                   drop_last=False,
                                   num_workers=num_workers,
                                   pin_memory=True)

    trainer.fit(train_loader, val_loader, config.nb_epoch)
Example #6
class TrackerSiamFC(Tracker):
    def __init__(self, net_path=None, **kwargs):
        super(TrackerSiamFC, self).__init__('SiamFC', True)
        self.cfg = self.parse_args(**kwargs)
        self.update_id = 2
        # setup GPU device if available
        self.cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.cuda else 'cpu')

        # setup model
        self.net = Net(backbone=AlexNetV1(), head=SiamFC(self.cfg.out_scale))

        ops.init_weights(self.net)

        # load checkpoint if provided
        if net_path is not None:
            self.net.load_state_dict(
                torch.load(net_path,
                           map_location=lambda storage, loc: storage))
        self.net = self.net.to(self.device)

        # setup criterion
        self.criterion = BalancedLoss()

        # setup optimizer
        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=self.cfg.initial_lr,
                                   weight_decay=self.cfg.weight_decay,
                                   momentum=self.cfg.momentum)

        # setup lr scheduler
        gamma = np.power(self.cfg.ultimate_lr / self.cfg.initial_lr,
                         1.0 / self.cfg.epoch_num)
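        # gamma is chosen so that initial_lr * gamma ** epoch_num == ultimate_lr,
        # i.e. the lr decays geometrically from initial_lr down to ultimate_lr over training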

        self.lr_scheduler = ExponentialLR(self.optimizer, gamma)

    def parse_args(self, **kwargs):
        # default parameters
        cfg = {
            # basic parameters
            'out_scale': 0.001,
            'exemplar_sz': 127,  #127 x 127
            'instance_sz': 255,
            'context': 0.5,
            # inference parameters
            'scale_num_init': 9,
            'scale_num': 3,
            'scale_step': 1.0375,
            'scale_lr': 0.59,
            'scale_penalty': 0.9745,
            'window_influence': 0.176,
            'response_sz': 17,  #17 x 17
            'response_up': 16,
            'total_stride': 8,
            # train parameters
            'epoch_num': 50,
            'batch_size': 8,
            'num_workers': 32,
            'initial_lr': 1e-2,
            'ultimate_lr': 1e-5,
            'weight_decay': 5e-4,
            'momentum': 0.9,
            'r_pos': 16,
            'r_neg': 0
        }

        for key, val in kwargs.items():
            if key in cfg:
                cfg.update({key: val})
        return namedtuple('Config', cfg.keys())(**cfg)

    @torch.no_grad()
    def init(self, img, box):
        # set to evaluation mode
        self.net.eval()

        # convert box to 0-indexed and center based
        box = np.array([
            box[1] - 1 + (box[3] - 1) / 2, box[0] - 1 +
            (box[2] - 1) / 2, box[3], box[2]
        ],
                       dtype=np.float32)
        self.center, self.target_sz = box[:2], box[2:]

        # create hanning window
        self.upscale_sz = self.cfg.response_up * self.cfg.response_sz
        self.hann_window = np.outer(np.hanning(self.upscale_sz),
                                    np.hanning(self.upscale_sz))
        self.hann_window /= self.hann_window.sum()

        # search scale factors_init
        self.scale_factors_init = self.cfg.scale_step**np.linspace(
            -(self.cfg.scale_num_init // 8), self.cfg.scale_num_init // 8,
            self.cfg.scale_num_init)
        # search scale factors
        self.scale_factors = self.cfg.scale_step**np.linspace(
            -(self.cfg.scale_num // 2), self.cfg.scale_num // 2,
            self.cfg.scale_num)

        # exemplar and search sizes
        context = self.cfg.context * np.sum(self.target_sz)
        self.z_sz = np.sqrt(np.prod(self.target_sz + context))
        self.x_sz = self.z_sz * self.cfg.instance_sz / self.cfg.exemplar_sz

        # exemplar image
        self.avg_color = np.mean(img, axis=(0, 1))
        z = ops.crop_and_resize(img,
                                self.center,
                                self.z_sz,
                                out_size=self.cfg.exemplar_sz,
                                border_value=self.avg_color)

        # exemplar features
        z = torch.from_numpy(z).to(self.device).permute(
            2, 0, 1).unsqueeze(0).float()
        self.kernel = self.net.backbone(z)  #AlexNetV1()

    @torch.no_grad()
    def update(self, count, img_files):
        # set to evaluation mode
        self.net.eval()
        x = []
        Listresponses = []

        #setting scale_num and scale_factors
        if count == 1:
            scale_num = self.cfg.scale_num_init
            scale_factors = self.scale_factors_init
            panalty_score = 8
            for fr, img_file in enumerate(img_files):
                img = ops.read_image(img_file)
                x.append(
                    ops.crop_and_resize(img,
                                        self.center,
                                        self.x_sz * 1.0,
                                        out_size=self.cfg.instance_sz,
                                        border_value=self.avg_color))
        else:
            scale_num = self.cfg.scale_num
            scale_factors = self.scale_factors
            panalty_score = 2
            for fr, img_file in enumerate(img_files):
                img = ops.read_image(img_file)
                x = [
                    ops.crop_and_resize(img,
                                        self.center,
                                        self.x_sz * f,
                                        out_size=self.cfg.instance_sz,
                                        border_value=self.avg_color)
                    for f in scale_factors
                ]

                # search images
                x_r = np.stack(x, axis=0)
                x_r = torch.from_numpy(x_r).to(self.device).permute(
                    0, 3, 1, 2).float()

                # responses
                x_r = self.net.backbone(x_r)  #AlexNetV1()
                responses = self.net.head(self.kernel, x_r)
                responses = responses.squeeze(1).cpu().numpy()

                # upsample responses and penalize scale changes
                responses = np.stack([
                    cv2.resize(u, (self.upscale_sz, self.upscale_sz),
                               interpolation=cv2.INTER_CUBIC)
                    for u in responses
                ])  #responses.shape = 3x272x272
                responses[:scale_num //
                          panalty_score] *= self.cfg.scale_penalty
                responses[scale_num // panalty_score +
                          1:] *= self.cfg.scale_penalty

                # peak scale
                pc = np.amax(responses, axis=(1, 2))
                scale_id = np.argmax(np.amax(responses, axis=(1, 2)))
                Listresponses.append(x[scale_id])

            x = Listresponses

        # search images
        search = x
        x = np.stack(x, axis=0)
        x = torch.from_numpy(x).to(self.device).permute(0, 3, 1, 2).float()

        # responses
        x = self.net.backbone(x)  #AlexNetV1()
        responses = self.net.head(self.kernel, x)
        responses = responses.squeeze(1).cpu().numpy()

        # upsample responses and penalize scale changes
        responses = np.stack([
            cv2.resize(u, (self.upscale_sz, self.upscale_sz),
                       interpolation=cv2.INTER_CUBIC) for u in responses
        ])
        responses[:scale_num // panalty_score] *= self.cfg.scale_penalty
        responses[scale_num // panalty_score + 1:] *= self.cfg.scale_penalty

        # peak scale
        pc = np.amax(responses, axis=(1, 2))

        scale_id = np.argmax(np.amax(responses, axis=(1, 2)))

        # frame id parsed from the chosen image file name
        png = img_files[scale_id].split('\\')[-1]
        png_id = int(png.split('.')[-2])

        # peak location
        response = responses[scale_id]
        response -= response.min()
        response /= response.sum() + 1e-16
        response = (1 - self.cfg.window_influence
                    ) * response + self.cfg.window_influence * self.hann_window
        loc = np.unravel_index(response.argmax(), response.shape)

        # locate target center
        disp_in_response = np.array(loc) - (self.upscale_sz - 1) / 2
        disp_in_instance = disp_in_response * self.cfg.total_stride / self.cfg.response_up
        disp_in_image = disp_in_instance * self.x_sz * scale_factors[
            scale_id] / self.cfg.instance_sz
        self.center += disp_in_image

        scale = (1 - self.cfg.scale_lr
                 ) * 1.0 + self.cfg.scale_lr * scale_factors[scale_id]
        self.target_sz *= scale
        self.z_sz *= scale
        self.x_sz *= scale

        # return 1-indexed and left-top based bounding box
        box = np.array([
            self.center[1] + 1 - (self.target_sz[1] - 1) / 2,
            self.center[0] + 1 - (self.target_sz[0] - 1) / 2,
            self.target_sz[1], self.target_sz[0]
        ])

        return box, responses, search, png_id

    def track(self, first_img, focal_dirs, show_imgs, box, visualize=False):
        frame_num = len(focal_dirs)
        boxes = np.zeros((frame_num, 4))
        boxes[0] = box
        times = np.zeros(frame_num)
        response = np.zeros((3, 272, 272))

        gt = open("recode/bbox.txt", 'w')
        try:
            for f, focal_dir in enumerate(focal_dirs):
                focal_plane = glob.glob(focal_dir + "/*")

                begin = time.time()
                if f == 0:
                    self.init(first_img, box)
                elif f == 1:
                    sharp = sharpfiles(focal_dir)
                    boxes[f, :], response, search, scale_id = self.update(
                        f, sharp)  #picked
                else:
                    focal_planes = focal_plane[scale_id - 1:scale_id + 1 + 1]
                    boxes[f, :], response, search, scale_id = self.update(
                        f, focal_planes)  # list of focal-plane images
                times[f] = time.time() - begin

                visualize = response[0][0][0] != 0.0

                if visualize:
                    dimg = ops.read_image(show_imgs[f])
                    save_img, centerloc = ops.show_image(dimg,
                                                         boxes[f, :],
                                                         None,
                                                         num_name=f)
                    save_img = cv2.putText(save_img, str(scale_id), (100, 30),
                                           cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                           (0, 255, 0), 2)
                    savepath = "./image/{0:0=3d}.png".format(f)
                    cv2.imwrite(savepath, save_img)
                    #groundtruth
                    center_coo = "%d,%d,%d,%d\n" % (int(
                        centerloc[0]), int(centerloc[1]), int(
                            centerloc[2]), int(centerloc[3]))
                    gt.write(center_coo)

        finally:
            gt.close()

        return boxes, times, response

    def train_step(self, batch, backward=True):
        # set network mode
        self.net.train(backward)

        # parse batch data
        z = batch[0].to(self.device, non_blocking=self.cuda)
        x = batch[1].to(self.device, non_blocking=self.cuda)

        with torch.set_grad_enabled(backward):
            # inference
            responses = self.net(z, x)

            # calculate loss
            labels = self._create_labels(responses.size())
            loss = self.criterion(responses, labels)

            if backward:
                # back propagation
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        return loss.item()

    @torch.enable_grad()
    def train_over(self, seqs, val_seqs=None, save_dir='pretrained'):
        # set to train mode
        self.net.train()

        # create save_dir folder
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # setup dataset
        transforms = SiamFCTransforms(exemplar_sz=self.cfg.exemplar_sz,
                                      instance_sz=self.cfg.instance_sz,
                                      context=self.cfg.context)
        dataset = Pair(seqs=seqs, transforms=transforms)

        # setup dataloader
        dataloader = DataLoader(dataset,
                                batch_size=self.cfg.batch_size,
                                shuffle=True,
                                num_workers=self.cfg.num_workers,
                                pin_memory=self.cuda,
                                drop_last=True)

        # loop over epochs
        for epoch in range(self.cfg.epoch_num):
            # update lr at each epoch
            self.lr_scheduler.step(epoch=epoch)

            # loop over dataloader
            for it, batch in enumerate(dataloader):
                loss = self.train_step(batch, backward=True)
                print('Epoch: {} [{}/{}] Loss: {:.5f}'.format(
                    epoch + 1, it + 1, len(dataloader), loss))
                sys.stdout.flush()

            # save checkpoint
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            net_path = os.path.join(save_dir,
                                    'siamfc_alexnet_e%d.pth' % (epoch + 1))
            torch.save(self.net.state_dict(), net_path)

    def _create_labels(self, size):
        # skip if same sized labels already created
        if hasattr(self, 'labels') and self.labels.size() == size:
            return self.labels

        def logistic_labels(x, y, r_pos, r_neg):
            dist = np.abs(x) + np.abs(y)  # block distance
            labels = np.where(
                dist <= r_pos, np.ones_like(x),
                np.where(dist < r_neg,
                         np.ones_like(x) * 0.5, np.zeros_like(x)))
            return labels

        # distances along x- and y-axis
        n, c, h, w = size
        x = np.arange(w) - (w - 1) / 2
        y = np.arange(h) - (h - 1) / 2
        x, y = np.meshgrid(x, y)

        # create logistic labels
        r_pos = self.cfg.r_pos / self.cfg.total_stride
        r_neg = self.cfg.r_neg / self.cfg.total_stride
        labels = logistic_labels(x, y, r_pos, r_neg)

        # repeat to size
        labels = labels.reshape((1, 1, h, w))
        labels = np.tile(labels, (n, c, 1, 1))

        # convert to tensors
        self.labels = torch.from_numpy(labels).to(self.device).float()

        return self.labels
Example #7
def exp(parameters):
    if not ("gamma" in parameters["scheduler"]):
        parameters["scheduler"]["gamma"] = 0.1
    return ExponentialLR(
        parameters["optimizer_object"], parameters["scheduler"]["gamma"]
    )
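
A hypothetical call to this factory (torch, the optimizer and the model below are assumptions used only for illustration):

parameters = {
    "optimizer_object": torch.optim.SGD(model.parameters(), lr=0.1),  # model is assumed
    "scheduler": {},                     # no "gamma" key, so the 0.1 default is applied
}
scheduler = exp(parameters)              # ExponentialLR(optimizer, gamma=0.1)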
Example #8
    model = MultiLabelClassifier(large=large).cuda()
    # model = nn.DataParallel(model)
    params = model.parameters()
    optimizer = Adam(params, lr=1e-4, weight_decay=1e-6)

    if start_epoch > 0:
        checkpoints = torch.load(
            os.path.join(checkpoints_dir,
                         'model-epoch{}.pth'.format(start_epoch)))
        model.load_state_dict(checkpoints['model'])
        # score_model.load_state_dict(checkpoints['score'])
        # item_embedding.load_state_dict(checkpoints['item'])
        optimizer.load_state_dict(checkpoints['optimizer'])
        print("load checkpoints")
    # model = torch.nn.DataParallel(model)
    scheduler = ExponentialLR(optimizer, 0.96, last_epoch=start_epoch - 1)
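    # note: with last_epoch >= 0, PyTorch expects an 'initial_lr' entry in each of the
    # optimizer's param groups (normally written the first time a scheduler is attached,
    # and restored here via optimizer.load_state_dict); this resumes the decay at start_epoch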
    for epoch in range(start_epoch, 20):
        tbar = tqdm(loader)
        losses_clf = 0.
        losses_gen = 0.
        model.train()
        for i, (query, query_len, features, boxes, obj_len) in enumerate(tbar):
            query = query.cuda()
            features = features.cuda()
            obj_len = obj_len.cuda()
            boxes = boxes.cuda()
            optimizer.zero_grad()
            _, loss_clf, loss_gen = model(features, boxes, obj_len, query)
            loss_clf = loss_clf.mean()
            loss_gen = loss_gen.mean()
            loss = loss_clf + loss_gen
Example #9
            # ACIQ initialization
            laplace = {1: 2.83, 2: 3.89, 3: 5.05, 4: 6.2, 5: 7.41, 6: 8.64, 7: 9.89, 8: 11.16}
            gaus = {1: 1.71, 2: 2.15, 3: 2.55, 4: 2.93, 5: 3.28, 6: 3.61, 7: 3.92, 8: 4.2}

            #            ops.c.data = ops.running_mean + (ops.running_b * laplace[args.actBitwidth])
            ops.c.data = ops.running_mean + (3 * ops.running_std)

            if not args.gradual:
                ops.quant = True

    if len(args.gpu) > 1:  # parallel
        model = torch.nn.DataParallel(model.module, args.gpu)
        model = model.cuda()

    gamma = (0.01 ** (1 / float(args.epochs - args.gradEpochs)))
    scheduler = ExponentialLR(optimizer, gamma=gamma)
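    # gamma is set so that after (epochs - gradEpochs) scheduler steps the lr has
    # decayed by a total factor of 0.01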


    # log command line
    logging.info('CommandLine: {} PID: {} '
                 'Hostname: {} CUDA_VISIBLE_DEVICES {}'.format(argv, getpid(), gethostname(),
                                                               environ.get('CUDA_VISIBLE_DEVICES')))

    # mlflow
    mlflow.set_tracking_uri(os.path.join(baseFolder, 'mlruns_mxt'))

    mlflow.set_experiment(args.exp + '_' + args.model)

    if args.plotHist:
        global saveDic
        global num
Example #10
 def configure_optimizers(self):
     opt = Adam(params=self.model.parameters(), lr=self.hparams.learning_rate)
     scheduler = ExponentialLR(opt, gamma=self.hparams.gamma)
     return [opt], [scheduler]
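
In recent PyTorch Lightning versions the same optimizer/scheduler pairing can also be returned as a single dict, which makes the stepping interval explicit (a sketch based on the snippet above, not part of the original source):

 def configure_optimizers(self):
     opt = Adam(params=self.model.parameters(), lr=self.hparams.learning_rate)
     scheduler = ExponentialLR(opt, gamma=self.hparams.gamma)
     # "interval": "epoch" asks Lightning to call scheduler.step() once per epoch
     return {"optimizer": opt,
             "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"}}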
Example #11
 def set_params(self, transformer, validation_datagen, *args, **kwargs):
     self.validation_datagen = validation_datagen
     self.model = transformer.model
     self.optimizer = transformer.optimizer
     self.loss_function = transformer.loss_function
     self.lr_scheduler = ExponentialLR(self.optimizer, self.gamma, last_epoch=-1)
Example #12
        "params": MODEL.stem.parameters(),
        'lr': 0.0004
    },
    {
        "params": MODEL.features.parameters(),
        'lr': 0.0009
    },
    {
        "params": MODEL.classifier.parameters(),
        'lr': 0.09
    },
],
            momentum=0.9)

LR_SCHEDULERS = [
    ExponentialLR(OPTIM, gamma=0.9),
]

REDUCE_LR_ON_PLATEAU = ReduceLROnPlateau(OPTIM,
                                         mode='min',
                                         factor=0.5,
                                         patience=5,
                                         threshold=0.05,
                                         verbose=True)

EARLY_STOPPING_KWARGS = {
    'patience': 30,
    # 'score_function': None
}

LOG_INTERVAL = 100
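
Note that the two kinds of schedulers in this config are stepped differently in plain PyTorch: ExponentialLR.step() takes no arguments, while ReduceLROnPlateau.step() must be given the monitored metric. A minimal sketch, assuming a per-epoch validation loss val_loss computed elsewhere:

for scheduler in LR_SCHEDULERS:
    scheduler.step()                     # unconditional exponential decay
REDUCE_LR_ON_PLATEAU.step(val_loss)      # val_loss: assumed validation metric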
Example #13
    train_data = smallNORB(data_folder,
                           train=True,
                           download=download,
                           transform=train_transform)
    train_data.train_data = common_transform(train_data.train_data)
    test_data = smallNORB(data_folder,
                          train=False,
                          download=download,
                          transform=test_transform)
    test_data.test_data = common_transform(test_data.test_data)
    train_loader = DataLoader(train_data, batch_size, True, num_workers=n_cpu)
    train_loader = loader_wrapper(train_loader)
    test_loader = DataLoader(test_data, batch_size, False, num_workers=n_cpu)

    model = nn.DataParallel(CapsuleNet()).to(device)
    optimizer = Adam(model.parameters(), 3e-3, weight_decay=.0000002)
    scheduler = ExponentialLR(optimizer, 0.96)

    margin = lambda n: 0.2 + .79 / (1 + math.exp(-(min(10.0, n / 50000.0 - 4))))

    model.train()
    i_iter = 0
    while i_iter < n_iter:
        if i_iter % 20000 == 0:
            scheduler.step()
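            # i.e. the lr is multiplied by 0.96 once every 20000 iterations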

            model.eval()
            with torch.no_grad():
                total = correct = 0
                for xs, ys in test_loader:
                    xs = xs.to(device)
Example #14
 def optim_def(self):
     trainable = filter(lambda p: p.requires_grad, self.student.parameters())
     self.optim_student = optim.Adam(trainable, lr=self.learning_rate)
     if self.decay_rate > 0:
         self.scheduler = ExponentialLR(self.optim_student, self.decay_rate)
Example #15
class Trainer_KBQA(object):
    def __init__(self, args, logger=None):
        self.args = args
        self.logger = logger
        self.best_dev_performance = 0.0
        self.best_h1 = 0.0
        self.best_f1 = 0.0
        self.eps = args['eps']
        self.learning_rate = self.args['lr']
        self.test_batch_size = args['test_batch_size']
        self.device = torch.device('cuda' if args['use_cuda'] else 'cpu')
        self.train_kl = args['train_KL']
        self.num_step = args['num_step']
        self.use_label = args['use_label']
        self.reset_time = 0
        self.load_data(args)
        if 'decay_rate' in args:
            self.decay_rate = args['decay_rate']
        else:
            self.decay_rate = 0.98
        # self.mode = args['mode']
        # self.use_middle = args['use_middle']
        self.mode = "teacher"
        self.student = init_parallel(self.args, self.logger, len(self.entity2id), self.num_kb_relation,
                                     len(self.word2id))
        self.student.to(self.device)
        self.evaluator = Evaluator_nsm(args=args, student=self.student, entity2id=self.entity2id,
                                       relation2id=self.relation2id, device=self.device)
        self.load_pretrain()
        self.optim_def()

    def optim_def(self):
        trainable = filter(lambda p: p.requires_grad, self.student.parameters())
        self.optim_student = optim.Adam(trainable, lr=self.learning_rate)
        if self.decay_rate > 0:
            self.scheduler = ExponentialLR(self.optim_student, self.decay_rate)

    def load_data(self, args):
        dataset = load_data(args)
        self.train_data = dataset["train"]
        self.valid_data = dataset["valid"]
        self.test_data = dataset["test"]
        self.entity2id = dataset["entity2id"]
        self.relation2id = dataset["relation2id"]
        self.word2id = dataset["word2id"]
        self.num_kb_relation = self.test_data.num_kb_relation
        self.num_entity = len(self.entity2id)

    def load_pretrain(self):
        args = self.args
        if args['load_pretrain'] is not None:
            filename = os.path.join(args['checkpoint_dir'], args['load_pretrain'])
            print("Load ckpt from", filename)
            checkpoint = torch.load(filename)
            model_state_dict = checkpoint["model_state_dict"]
            model = self.student.model
            self.logger.info("Load param of {} from {}.".format(", ".join(list(model_state_dict.keys())), filename))
            model.load_state_dict(model_state_dict, strict=False)

    def evaluate(self, data, test_batch_size=20, mode="teacher", write_info=False):
        return self.evaluator.evaluate(data, test_batch_size, write_info)

    def train(self, start_epoch, end_epoch):
        # self.load_pretrain()
        eval_every = self.args['eval_every']
        # eval_acc = inference(self.model, self.valid_data, self.entity2id, self.args)
        self.evaluate(self.valid_data, self.test_batch_size, mode="teacher")
        print("Strat Training------------------")
        for epoch in range(start_epoch, end_epoch + 1):
            st = time.time()
            loss, extras, h1_list_all, f1_list_all = self.train_epoch()
            if self.decay_rate > 0:
                self.scheduler.step()
            # if self.mode == "student":
            #     self.student.update_target()
            # actor_loss, ent_loss = extras
            self.logger.info("Epoch: {}, loss : {:.4f}, time: {}".format(epoch + 1, loss, time.time() - st))
            self.logger.info("Training h1 : {:.4f}, f1 : {:.4f}".format(np.mean(h1_list_all), np.mean(f1_list_all)))
            extra_list = ["{}: {:.4f}".format(extra_item, np.mean(extras[extra_item])) for extra_item in extras]
            extra_str = " ".join(extra_list)
            self.logger.info(extra_str)
            # print("actor : {:.4f}, ent : {:.4f}".format(actor_loss, ent_loss))
            if (epoch + 1) % eval_every == 0 and epoch + 1 > 0:
                eval_f1, eval_h1 = self.evaluate(self.valid_data, self.test_batch_size, mode="teacher")
                # eval_f1 = np.mean(f1_list_all)
                # eval_h1 = np.mean(h1_list_all)
                self.logger.info("EVAL F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_h1))
                if eval_h1 > self.best_h1:
                    self.best_h1 = eval_h1
                    self.save_ckpt("h1")
                if eval_f1 > self.best_f1:
                    self.best_f1 = eval_f1
                    self.save_ckpt("f1")
                # self.reset_time = 0
                # else:
                #     self.logger.info('No improvement after one evaluation iter.')
                #     self.reset_time += 1
                # if self.reset_time >= 5:
                #     self.logger.info('No improvement after 5 evaluation. Early Stopping.')
                #     break
        self.save_ckpt("final")
        self.logger.info('Train Done! Evaluate on testset with saved model')
        print("End Training------------------")
        self.evaluate_best(self.mode)

    def evaluate_best(self, mode):
        filename = os.path.join(self.args['checkpoint_dir'], "{}-h1.ckpt".format(self.args['experiment_name']))
        self.load_ckpt(filename)
        eval_f1, eval_h1 = self.evaluate(self.test_data, self.test_batch_size, mode="teacher", write_info=False)
        self.logger.info("Best h1 evaluation")
        self.logger.info("TEST F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_h1))

        filename = os.path.join(self.args['checkpoint_dir'], "{}-f1.ckpt".format(self.args['experiment_name']))
        self.load_ckpt(filename)
        eval_f1, eval_h1 = self.evaluate(self.test_data, self.test_batch_size, mode="teacher", write_info=False)
        self.logger.info("Best f1 evaluation")
        self.logger.info("TEST F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_h1))

        filename = os.path.join(self.args['checkpoint_dir'], "{}-final.ckpt".format(self.args['experiment_name']))
        self.load_ckpt(filename)
        eval_f1, eval_h1 = self.evaluate(self.test_data, self.test_batch_size, mode="teacher", write_info=False)
        self.logger.info("Final evaluation")
        self.logger.info("TEST F1: {:.4f}, H1: {:.4f}".format(eval_f1, eval_h1))

    def evaluate_single(self, filename):
        if filename is not None:
            self.load_ckpt(filename)
        test_f1, test_hits = self.evaluate(self.test_data, self.test_batch_size, mode=self.mode, write_info=True)
        self.logger.info("TEST F1: {:.4f}, H1: {:.4f}".format(test_f1, test_hits))

    def train_epoch(self):
        self.student.train()
        self.train_data.reset_batches(is_sequential=False)
        losses = []
        actor_losses = []
        ent_losses = []
        num_epoch = math.ceil(self.train_data.num_data / self.args['batch_size'])
        h1_list_all = []
        f1_list_all = []
        extra_dict = {}
        extra_item_list = ["main", "back", "constrain"]
        for srt_ in extra_item_list:
            extra_dict[srt_] = []
        for iteration in tqdm(range(num_epoch)):
            batch = self.train_data.get_batch(iteration, self.args['batch_size'], self.args['fact_drop'])
            # label_dist, label_valid = self.train_data.get_label()
            # loss = self.train_step_student(batch, label_dist, label_valid)
            self.optim_student.zero_grad()
            loss, extras, _, tp_list = self.student(batch, training=True)
            for i, extra_item in enumerate(extra_item_list):
                extra_dict[extra_item].append(extras[i])
            h1_list, f1_list = tp_list
            h1_list_all.extend(h1_list)
            f1_list_all.extend(f1_list)
            loss.backward()
            torch.nn.utils.clip_grad_norm_([param for name, param in self.student.named_parameters()],
                                           self.args['gradient_clip'])
            self.optim_student.step()
            losses.append(loss.item())
        return np.mean(losses), extra_dict, h1_list_all, f1_list_all

    def save_ckpt(self, reason="h1"):
        model = self.student
        checkpoint = {
            'model_state_dict': model.state_dict()
        }
        model_name = os.path.join(self.args['checkpoint_dir'], "{}-{}.ckpt".format(self.args['experiment_name'],
                                                                                   reason))
        torch.save(checkpoint, model_name)
        print("Best %s, save model as %s" %(reason, model_name))

    def load_ckpt(self, filename):
        checkpoint = torch.load(filename)
        model_state_dict = checkpoint["model_state_dict"]
        model = self.student
        # model = self.student
        self.logger.info("Load param of {} from {}.".format(", ".join(list(model_state_dict.keys())), filename))
        model.load_state_dict(model_state_dict, strict=False)
def main(_run, _config, _seed, _log):
    """

    :param _run:
    :param _config:
    :param _seed:
    :param _log:
    :return:
    """
    """
    Setting and loading parameters
    """
    # Setting logger
    args = _config
    logger = _log

    logger.info(args)
    logger.info('It started at: %s' % datetime.now())

    torch.manual_seed(_seed)
    bugReportDatabase = BugReportDatabase.fromJson(args['bug_database'])
    paddingSym = "</s>"
    batchSize = args['batch_size']

    device = torch.device('cuda' if args['cuda'] else "cpu")

    if args['cuda']:
        logger.info("Turning CUDA on")
    else:
        logger.info("Turning CUDA off")

    # It is the folder where the preprocessed information will be stored.
    cacheFolder = args['cache_folder']

    # Setting the parameter to save and loading parameters
    importantParameters = ['compare_aggregation', 'categorical']
    parametersToSave = dict([(parName, args[parName])
                             for parName in importantParameters])

    if args['load'] is not None:
        mapLocation = (
            lambda storage, loc: storage.cuda()) if args['cuda'] else 'cpu'
        modelInfo = torch.load(args['load'], map_location=mapLocation)
        modelState = modelInfo['model']

        for paramName, paramValue in modelInfo['params'].items():
            args[paramName] = paramValue
    else:
        modelState = None

    preprocessors = PreprocessorList()
    inputHandlers = []

    categoricalOpt = args.get('categorical')

    if categoricalOpt is not None and len(categoricalOpt) != 0:
        categoricalEncoder, _, _ = processCategoricalParam(
            categoricalOpt, bugReportDatabase, inputHandlers, preprocessors,
            None, logger)
    else:
        categoricalEncoder = None

    filterInputHandlers = []

    compareAggOpt = args['compare_aggregation']
    databasePath = args['bug_database']

    # Loading word embedding
    if compareAggOpt["lexicon"]:
        emb = np.load(compareAggOpt["word_embedding"])

        lexicon = Lexicon(unknownSymbol=None)
        with codecs.open(compareAggOpt["lexicon"]) as f:
            for l in f:
                lexicon.put(l.strip())

        lexicon.setUnknown("UUUKNNN")
        paddingId = lexicon.getLexiconIndex(paddingSym)
        embedding = Embedding(lexicon, emb, paddingIdx=paddingId)

        logger.info("Lexicon size: %d" % (lexicon.getLen()))
        logger.info("Word Embedding size: %d" % (embedding.getEmbeddingSize()))
    elif compareAggOpt["word_embedding"]:
        # todo: Allow use embeddings and other representation
        lexicon, embedding = Embedding.fromFile(
            compareAggOpt['word_embedding'],
            'UUUKNNN',
            hasHeader=False,
            paddingSym=paddingSym)
        logger.info("Lexicon size: %d" % (lexicon.getLen()))
        logger.info("Word Embedding size: %d" % (embedding.getEmbeddingSize()))
        paddingId = lexicon.getLexiconIndex(paddingSym)
    else:
        embedding = None

    if compareAggOpt["norm_word_embedding"]:
        embedding.zscoreNormalization()

    # Tokenizer
    if compareAggOpt['tokenizer'] == 'default':
        logger.info("Use default tokenizer to tokenize summary information")
        tokenizer = MultiLineTokenizer()
    elif compareAggOpt['tokenizer'] == 'white_space':
        logger.info(
            "Use white space tokenizer to tokenize summary information")
        tokenizer = WhitespaceTokenizer()
    else:
        raise ArgumentError(
            "Tokenizer value %s is invalid. You should choose one of these: default and white_space"
            % compareAggOpt['tokenizer'])

    # Preparing input handlers, preprocessors and cache
    minSeqSize = max(compareAggOpt['aggregate']["window"]
                     ) if compareAggOpt['aggregate']["model"] == "cnn" else -1
    bow = compareAggOpt.get('bow', False)
    freq = compareAggOpt.get('frequency', False) and bow

    logger.info("BoW={} and TF={}".format(bow, freq))

    if compareAggOpt['extractor'] is not None:
        # Use summary and description (concatenated) to address this problem
        logger.info("Using Summary and Description information.")
        # Loading Filters
        extractorFilters = loadFilters(compareAggOpt['extractor']['filters'])

        arguments = (databasePath, compareAggOpt['word_embedding'],
                     str(compareAggOpt['lexicon']), ' '.join(
                         sorted([
                             fil.__class__.__name__ for fil in extractorFilters
                         ])), compareAggOpt['tokenizer'], str(bow), str(freq),
                     SABDEncoderPreprocessor.__name__)

        inputHandlers.append(SABDInputHandler(paddingId, minSeqSize))
        extractorCache = PreprocessingCache(cacheFolder, arguments)

        if bow:
            extractorPreprocessor = SABDBoWPreprocessor(
                lexicon, bugReportDatabase, extractorFilters, tokenizer,
                paddingId, freq, extractorCache)
        else:
            extractorPreprocessor = SABDEncoderPreprocessor(
                lexicon, bugReportDatabase, extractorFilters, tokenizer,
                paddingId, extractorCache)
        preprocessors.append(extractorPreprocessor)

    # Create model
    model = SABD(embedding, categoricalEncoder, compareAggOpt['extractor'],
                 compareAggOpt['matching'], compareAggOpt['aggregate'],
                 compareAggOpt['classifier'], freq)

    if args['loss'] == 'bce':
        logger.info("Using BCE Loss: margin={}".format(args['margin']))
        lossFn = BCELoss()
        lossNoReduction = BCELoss(reduction='none')
        cmp_collate = PairBugCollate(inputHandlers,
                                     torch.float32,
                                     unsqueeze_target=True)
    elif args['loss'] == 'triplet':
        logger.info("Using Triplet Loss: margin={}".format(args['margin']))
        lossFn = TripletLoss(args['margin'])
        lossNoReduction = TripletLoss(args['margin'], reduction='none')
        cmp_collate = TripletBugCollate(inputHandlers)

    model.to(device)

    if modelState:
        model.load_state_dict(modelState)
    """
    Loading the training and validation. Also, it sets how the negative example will be generated.
    """
    # load training
    if args.get('pairs_training'):
        negativePairGenOpt = args.get('neg_pair_generator', )
        trainingFile = args.get('pairs_training')

        offlineGeneration = not (negativePairGenOpt is None
                                 or negativePairGenOpt['type'] == 'none')
        masterIdByBugId = bugReportDatabase.getMasterIdByBugId()
        randomAnchor = negativePairGenOpt['random_anchor']

        if not offlineGeneration:
            logger.info("Not generate dynamically the negative examples.")
            negativePairGenerator = None
        else:
            pairGenType = negativePairGenOpt['type']

            if pairGenType == 'random':
                logger.info("Random Negative Pair Generator")
                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (trainingDataset.info, len(bugIds)))

                negativePairGenerator = RandomGenerator(
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    randomAnchor=randomAnchor)

            elif pairGenType == 'non_negative':
                logger.info("Non Negative Pair Generator")
                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (trainingDataset.info, len(bugIds)))

                negativePairGenerator = NonNegativeRandomGenerator(
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    negativePairGenOpt['n_tries'],
                    device,
                    randomAnchor=randomAnchor)
            elif pairGenType == 'misc_non_zero':
                logger.info("Misc Non Zero Pair Generator")
                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (trainingDataset.info, len(bugIds)))

                negativePairGenerator = MiscNonZeroRandomGen(
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    trainingDataset.duplicateIds,
                    masterIdByBugId,
                    negativePairGenOpt['n_tries'],
                    device,
                    randomAnchor=randomAnchor)
            elif pairGenType == 'product_component':
                logger.info("Product Component Pair Generator")
                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (trainingDataset.info, len(bugIds)))

                negativePairGenerator = ProductComponentRandomGen(
                    bugReportDatabase,
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    negativePairGenOpt['n_tries'],
                    device,
                    randomAnchor=randomAnchor)

            elif pairGenType == 'random_k':
                logger.info("Random K Negative Pair Generator")
                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                logger.info(
                    "Using the following dataset to generate negative examples: %s. Number of bugs in the training: %d"
                    % (trainingDataset.info, len(bugIds)))

                negativePairGenerator = KRandomGenerator(
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    negativePairGenOpt['k'],
                    device,
                    randomAnchor=randomAnchor)
            elif pairGenType == "pre":
                logger.info("Pre-selected list generator")
                negativePairGenerator = PreSelectedGenerator(
                    negativePairGenOpt['pre_list_file'],
                    preprocessors,
                    negativePairGenOpt['rate'],
                    masterIdByBugId,
                    negativePairGenOpt['preselected_length'],
                    randomAnchor=randomAnchor)

            elif pairGenType == "positive_pre":
                logger.info("Positive Pre-selected list generator")
                negativePairGenerator = PositivePreSelectedGenerator(
                    negativePairGenOpt['pre_list_file'],
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    masterIdByBugId,
                    negativePairGenOpt['preselected_length'],
                    randomAnchor=randomAnchor)
            elif pairGenType == "misc_non_zero_pre":
                logger.info("Misc: non-zero and Pre-selected list generator")
                negativePairGenerator1 = PreSelectedGenerator(
                    negativePairGenOpt['pre_list_file'],
                    preprocessors,
                    negativePairGenOpt['rate'],
                    masterIdByBugId,
                    negativePairGenOpt['preselected_length'],
                    randomAnchor=randomAnchor)

                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                negativePairGenerator2 = NonNegativeRandomGenerator(
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    negativePairGenOpt['n_tries'],
                    device,
                    randomAnchor=randomAnchor)

                negativePairGenerator = MiscOfflineGenerator(
                    (negativePairGenerator1, negativePairGenerator2))
            elif pairGenType == "misc_non_zero_positive_pre":
                logger.info(
                    "Misc: non-zero and Positive Pre-selected list generator")
                negativePairGenerator1 = PositivePreSelectedGenerator(
                    negativePairGenOpt['pre_list_file'],
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    masterIdByBugId,
                    negativePairGenOpt['preselected_length'],
                    randomAnchor=randomAnchor)

                trainingDataset = BugDataset(negativePairGenOpt['training'])
                bugIds = trainingDataset.bugIds

                negativePairGenerator2 = NonNegativeRandomGenerator(
                    preprocessors,
                    cmp_collate,
                    negativePairGenOpt['rate'],
                    bugIds,
                    masterIdByBugId,
                    negativePairGenOpt['n_tries'],
                    device,
                    randomAnchor=randomAnchor)

                negativePairGenerator = MiscOfflineGenerator(
                    (negativePairGenerator1, negativePairGenerator2))

            else:
                raise ArgumentError(
                    "Offline generator is invalid (%s). You should choose one of these: random, hard and pre"
                    % pairGenType)

        if isinstance(lossFn, BCELoss):
            training_reader = PairBugDatasetReader(
                trainingFile,
                preprocessors,
                negativePairGenerator,
                randomInvertPair=args['random_switch'])
        elif isinstance(lossFn, TripletLoss):
            training_reader = TripletBugDatasetReader(
                trainingFile,
                preprocessors,
                negativePairGenerator,
                randomInvertPair=args['random_switch'])

        trainingLoader = DataLoader(training_reader,
                                    batch_size=batchSize,
                                    collate_fn=cmp_collate.collate,
                                    shuffle=True)
        logger.info("Training size: %s" % (len(trainingLoader.dataset)))

    # load validation
    if args.get('pairs_validation'):
        if isinstance(lossFn, BCELoss):
            validation_reader = PairBugDatasetReader(
                args.get('pairs_validation'), preprocessors)
        elif isinstance(lossFn, TripletLoss):
            validation_reader = TripletBugDatasetReader(
                args.get('pairs_validation'), preprocessors)

        validationLoader = DataLoader(validation_reader,
                                      batch_size=batchSize,
                                      collate_fn=cmp_collate.collate)

        logger.info("Validation size: %s" % (len(validationLoader.dataset)))
    else:
        validationLoader = None
    """
    Training and evaluate the model. 
    """
    optimizer_opt = args.get('optimizer', 'adam')

    if optimizer_opt == 'sgd':
        logger.info('SGD')
        optimizer = optim.SGD(model.parameters(),
                              lr=args['lr'],
                              weight_decay=args['l2'])
    elif optimizer_opt == 'adam':
        logger.info('Adam')
        optimizer = optim.Adam(model.parameters(),
                               lr=args['lr'],
                               weight_decay=args['l2'])

    # Recall rate
    rankingScorer = GeneralScorer(
        model, preprocessors, device,
        PairBugCollate(inputHandlers, ignore_target=True),
        args['ranking_batch_size'], args['ranking_n_workers'])
    recallEstimationTrainOpt = args.get('recall_estimation_train')

    if recallEstimationTrainOpt:
        preselectListRankingTrain = PreselectListRanking(
            recallEstimationTrainOpt, args['sample_size_rr_tr'])

    recallEstimationOpt = args.get('recall_estimation')

    if recallEstimationOpt:
        preselectListRanking = PreselectListRanking(recallEstimationOpt,
                                                    args['sample_size_rr_val'])

    # LR scheduler
    lrSchedulerOpt = args.get('lr_scheduler', None)

    if lrSchedulerOpt is None:
        logger.info("Scheduler: Constant")
        lrSched = None
    elif lrSchedulerOpt["type"] == 'step':
        logger.info("Scheduler: StepLR (step:%s, decay:%f)" %
                    (lrSchedulerOpt["step_size"], args["decay"]))
        lrSched = StepLR(optimizer, lrSchedulerOpt["step_size"],
                         lrSchedulerOpt["decay"])
    elif lrSchedulerOpt["type"] == 'exp':
        logger.info("Scheduler: ExponentialLR (decay:%f)" %
                    (lrSchedulerOpt["decay"]))
        lrSched = ExponentialLR(optimizer, lrSchedulerOpt["decay"])
    elif lrSchedulerOpt["type"] == 'linear':
        logger.info(
            "Scheduler: Divide by (1 + epoch * decay) ---- (decay:%f)" %
            (lrSchedulerOpt["decay"]))

        lrDecay = lrSchedulerOpt["decay"]
        lrSched = LambdaLR(optimizer, lambda epoch: 1 /
                           (1.0 + epoch * lrDecay))
    else:
        raise ArgumentError(
            "LR Scheduler is invalid (%s). You should choose one of these: step, exp and linear"
            % lrSchedulerOpt["type"])

    # Set training functions
    def trainingIteration(engine, batch):
        engine.kk = 0
        model.train()

        optimizer.zero_grad()
        x, y = cmp_collate.to(batch, device)
        output = model(*x)
        loss = lossFn(output, y)
        loss.backward()
        optimizer.step()
        return loss, output, y

    def scoreDistanceTrans(output):
        if len(output) == 3:
            _, y_pred, y = output
        else:
            y_pred, y = output

        if lossFn == F.nll_loss:
            return torch.exp(y_pred[:, 1]), y
        elif isinstance(lossFn, (BCELoss)):
            return y_pred, y

    trainer = Engine(trainingIteration)
    trainingMetrics = {'training_loss': AverageLoss(lossFn)}

    if isinstance(lossFn, BCELoss):
        trainingMetrics['training_dist_target'] = MeanScoreDistance(
            output_transform=scoreDistanceTrans)
        trainingMetrics['training_acc'] = AccuracyWrapper(
            output_transform=thresholded_output_transform)
        trainingMetrics['training_precision'] = PrecisionWrapper(
            output_transform=thresholded_output_transform)
        trainingMetrics['training_recall'] = RecallWrapper(
            output_transform=thresholded_output_transform)
    # Add metrics to trainer
    for name, metric in trainingMetrics.items():
        metric.attach(trainer, name)

    # Set validation functions
    def validationIteration(engine, batch):
        if not hasattr(engine, 'kk'):
            engine.kk = 0

        model.eval()

        with torch.no_grad():
            x, y = cmp_collate.to(batch, device)
            y_pred = model(*x)

            return y_pred, y

    validationMetrics = {
        'validation_loss':
        LossWrapper(lossFn,
                    output_transform=lambda x: (x[0], x[0][0])
                    if x[1] is None else x)
    }

    if isinstance(lossFn, BCELoss):
        validationMetrics['validation_dist_target'] = MeanScoreDistance(
            output_transform=scoreDistanceTrans)
        validationMetrics['validation_acc'] = AccuracyWrapper(
            output_transform=thresholded_output_transform)
        validationMetrics['validation_precision'] = PrecisionWrapper(
            output_transform=thresholded_output_transform)
        validationMetrics['validation_recall'] = RecallWrapper(
            output_transform=thresholded_output_transform)

    evaluator = Engine(validationIteration)

    # Add metrics to evaluator
    for name, metric in validationMetrics.items():
        metric.attach(evaluator, name)

    # recommendation
    recommendation_fn = generateRecommendationList

    @trainer.on(Events.EPOCH_STARTED)
    def onStartEpoch(engine):
        epoch = engine.state.epoch
        logger.info("Epoch: %d" % epoch)

        if lrSched:
            lrSched.step()

        logger.info("LR: %s" % str(optimizer.param_groups[0]["lr"]))

    @trainer.on(Events.EPOCH_COMPLETED)
    def onEndEpoch(engine):
        epoch = engine.state.epoch

        logMetrics(_run, logger, engine.state.metrics, epoch)

        # Evaluate Training
        if validationLoader:
            evaluator.run(validationLoader)
            logMetrics(_run, logger, evaluator.state.metrics, epoch)

        lastEpoch = (epoch == args['epochs'])

        if recallEstimationTrainOpt and (epoch % args['rr_train_epoch'] == 0):
            logRankingResult(_run,
                             logger,
                             preselectListRankingTrain,
                             rankingScorer,
                             bugReportDatabase,
                             None,
                             epoch,
                             "train",
                             recommendationListfn=recommendation_fn)
            rankingScorer.free()

        if recallEstimationOpt and (epoch % args['rr_val_epoch'] == 0):
            logRankingResult(_run,
                             logger,
                             preselectListRanking,
                             rankingScorer,
                             bugReportDatabase,
                             args.get("ranking_result_file"),
                             epoch,
                             "validation",
                             recommendationListfn=recommendation_fn)
            rankingScorer.free()

        if not lastEpoch:
            training_reader.sampleNewNegExamples(model, lossNoReduction)

        if args.get('save'):
            save_by_epoch = args['save_by_epoch']

            if save_by_epoch and epoch in save_by_epoch:
                file_name, file_extension = os.path.splitext(args['save'])
                file_path = file_name + '_epoch_{}'.format(
                    epoch) + file_extension
            else:
                file_path = args['save']

            modelInfo = {
                'model': model.state_dict(),
                'params': parametersToSave
            }

            logger.info("==> Saving Model: %s" % file_path)
            torch.save(modelInfo, file_path)

    if args.get('pairs_training'):
        trainer.run(trainingLoader, max_epochs=args['epochs'])
    elif args.get('pairs_validation'):
        # Evaluate Training
        evaluator.run(validationLoader)
        logMetrics(_run, logger, evaluator.state.metrics, 0)

        if recallEstimationOpt:
            logRankingResult(_run,
                             logger,
                             preselectListRanking,
                             rankingScorer,
                             bugReportDatabase,
                             args.get("ranking_result_file"),
                             0,
                             "validation",
                             recommendationListfn=recommendation_fn)

    # Test Dataset (accuracy, recall, precision, F1)
    pair_test_dataset = args.get('pair_test_dataset')

    if pair_test_dataset is not None and len(pair_test_dataset) > 0:
        pairTestReader = PairBugDatasetReader(pair_test_dataset, preprocessors)
        testLoader = DataLoader(pairTestReader,
                                batch_size=batchSize,
                                collate_fn=cmp_collate.collate)

        if not isinstance(cmp_collate, PairBugCollate):
            raise NotImplementedError(
                'Evaluation of pairs using tanh was not implemented yet')

        logger.info("Test size: %s" % (len(testLoader.dataset)))

        testMetrics = {
            'test_accuracy':
            ignite.metrics.Accuracy(
                output_transform=thresholded_output_transform),
            'test_precision':
            ignite.metrics.Precision(
                output_transform=thresholded_output_transform),
            'test_recall':
            ignite.metrics.Recall(
                output_transform=thresholded_output_transform),
            'test_predictions':
            PredictionCache(),
        }
        test_evaluator = Engine(validationIteration)

        # Add metrics to evaluator
        for name, metric in testMetrics.items():
            metric.attach(test_evaluator, name)

        test_evaluator.run(testLoader)

        for metricName, metricValue in test_evaluator.state.metrics.items():
            metric = testMetrics[metricName]

            if isinstance(metric, ignite.metrics.Accuracy):
                logger.info({
                    'type': 'metric',
                    'label': metricName,
                    'value': metricValue,
                    'epoch': None,
                    'correct': metric._num_correct,
                    'total': metric._num_examples
                })
                _run.log_scalar(metricName, metricValue)
            elif isinstance(metric,
                            (ignite.metrics.Precision, ignite.metrics.Recall)):
                logger.info({
                    'type': 'metric',
                    'label': metricName,
                    'value': metricValue,
                    'epoch': None,
                    'tp': metric._true_positives.item(),
                    'total_positive': metric._positives.item()
                })
                _run.log_scalar(metricName, metricValue)
            elif isinstance(metric, ConfusionMatrix):
                acc = cmAccuracy(metricValue)
                prec = cmPrecision(metricValue, False)
                recall = cmRecall(metricValue, False)
                f1 = 2 * (prec * recall) / (prec + recall + 1e-15)

                logger.info({
                    'type': 'metric',
                    'label': metricName,
                    'accuracy': float(acc),
                    'precision': prec.cpu().numpy().tolist(),
                    'recall': recall.cpu().numpy().tolist(),
                    'f1': f1.cpu().numpy().tolist(),
                    'confusion_matrix': metricValue.cpu().numpy().tolist(),
                    'epoch': None
                })

                _run.log_scalar('test_f1', f1[1])
            elif isinstance(metric, PredictionCache):
                logger.info({
                    'type': 'metric',
                    'label': metricName,
                    'predictions': metric.predictions
                })

    # Calculate recall rate
    recallRateOpt = args.get('recall_rate', {'type': 'none'})
    if recallRateOpt['type'] != 'none':
        if recallRateOpt['type'] == 'sun2011':
            logger.info("Calculating recall rate: {}".format(
                recallRateOpt['type']))
            recallRateDataset = BugDataset(recallRateOpt['dataset'])

            rankingClass = SunRanking(bugReportDatabase, recallRateDataset,
                                      recallRateOpt['window'])
            # In the Sun (2011) methodology, the results always group all bug reports by master
            group_by_master = True
        elif recallRateOpt['type'] == 'deshmukh':
            logger.info("Calculating recall rate: {}".format(
                recallRateOpt['type']))
            recallRateDataset = BugDataset(recallRateOpt['dataset'])
            rankingClass = DeshmukhRanking(bugReportDatabase,
                                           recallRateDataset)
            group_by_master = recallRateOpt['group_by_master']
        else:
            raise ArgumentError(
                "recall_rate.type is invalid (%s). You should choose one of these: sun2011 and deshmukh"
                % recallRateOpt['type'])

        logRankingResult(_run,
                         logger,
                         rankingClass,
                         rankingScorer,
                         bugReportDatabase,
                         recallRateOpt["result_file"],
                         0,
                         None,
                         group_by_master,
                         recommendationListfn=recommendation_fn)
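
The script above chooses between a constant rate, StepLR, ExponentialLR and a LambdaLR-based 1 / (1 + epoch * decay) decay, and steps the chosen scheduler once per epoch in onStartEpoch. Below is a minimal, self-contained sketch of that selection-and-step pattern; the dummy model, learning rate and decay value are assumptions for illustration only.

from torch import nn, optim
from torch.optim.lr_scheduler import StepLR, ExponentialLR, LambdaLR

# stand-ins for the real model and config values (assumed for this sketch)
model = nn.Linear(10, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler_type, decay = 'exp', 0.95

if scheduler_type == 'step':
    lrSched = StepLR(optimizer, step_size=5, gamma=decay)
elif scheduler_type == 'exp':
    lrSched = ExponentialLR(optimizer, gamma=decay)  # lr_e = lr_0 * decay ** e
else:
    lrSched = LambdaLR(optimizer, lambda epoch: 1.0 / (1.0 + epoch * decay))

for epoch in range(3):
    # ... run one training epoch here ...
    lrSched.step()
    print(epoch, optimizer.param_groups[0]['lr'])
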
Example #17
0
class BPR_Factorizer(object):
    def __init__(self, opt):
        self.opt = opt
        self.clip = opt.get('grad_clip')
        self.use_cuda = opt.get('use_cuda')
        self.batch_size_test = opt.get('batch_size_test')

        # self.metron = MetronAtK(top_k=opt['metric_topk'])
        self.criterion = BCEWithLogitsLoss(size_average=False)

        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.l2_penalty = None

        self.param_grad = None
        self.optim_status = None

        self.prev_param = None
        self.param = None

        # is the factorizer for assumed update
        self.is_assumed = False

        self._train_step_idx = None
        self._train_episode_idx = None

    # @profile
    def copy(self, new_factorizer):
        """Return a new copy of factorizer

        # Note: directly using deepcopy wont copy factorizer.scheduler correctly
                the gradient of self.model is not copied!
        """
        self.train_step_idx = new_factorizer.train_step_idx
        self.param = new_factorizer.param
        self.model.load_state_dict(new_factorizer.model.state_dict())
        self.optimizer = use_optimizer(self.model, self.opt)
        self.optimizer.load_state_dict(new_factorizer.optimizer.state_dict())

        self.scheduler = ExponentialLR(self.optimizer,
                                       gamma=self.opt['lr_exp_decay'],
                                       last_epoch=self.scheduler.last_epoch)
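        # Rebuild the scheduler rather than deep-copying it: it must be bound to the
        # newly created optimizer, and passing last_epoch preserves the decay position.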

    @property
    def delta_param(self):
        """update of parameter, for SAC regularizer

        return:
            list of pytorch tensor
        """
        delta_param = list()
        for prev_p, p in zip(self.prev_param, self.param):
            # append instead of indexing into the (still empty) list
            delta_param.append(p - prev_p)
        return delta_param

    @property
    def train_step_idx(self):
        return self._train_step_idx

    @train_step_idx.setter
    def train_step_idx(self, new_step_idx):
        self._train_step_idx = new_step_idx

    @property
    def train_episode_idx(self):
        return self._train_episode_idx

    @train_episode_idx.setter
    def train_episode_idx(self, new_episode_idx):
        self._train_episode_idx = new_episode_idx

    def get_grad_norm(self):
        assert hasattr(self, 'model')
        return get_grad_norm(self.model)

    def update(self, sampler, l2_lambda):
        if (self.train_step_idx >
                0) and (self.train_step_idx % sampler.num_batches_train
                        == 0):  # sampler.get_num_batch_per_epoch('train')
            self.scheduler.step()
            print('\tfactorizer lr decay ...')
        self.train_step_idx += 1

        self.model.train()
        self.optimizer.zero_grad()
Example #18
0
def build_scheduler(config: dict, optimizer: Optimizer, scheduler_mode: str,
                    hidden_size: int = 0) \
        -> (Optional[_LRScheduler], Optional[str]):
    """
    Create a learning rate scheduler if specified in config and
    determine when a scheduler step should be executed.

    Current options:
        - "plateau": see `torch.optim.lr_scheduler.ReduceLROnPlateau`
        - "decaying": see `torch.optim.lr_scheduler.StepLR`
        - "exponential": see `torch.optim.lr_scheduler.ExponentialLR`
        - "noam": see `SignProdJoey.transformer.NoamScheduler`

    If no scheduler is specified, returns (None, None) which will result in
    a constant learning rate.

    :param config: training configuration
    :param optimizer: optimizer for the scheduler, determines the set of
        parameters which the scheduler sets the learning rate for
    :param scheduler_mode: "min" or "max", depending on whether the validation
        score should be minimized or maximized.
        Only relevant for "plateau".
    :param hidden_size: encoder hidden size (required for NoamScheduler)
    :return:
        - scheduler: scheduler object,
        - scheduler_step_at: either "validation" or "epoch"
    """
    scheduler, scheduler_step_at = None, None
    if "scheduling" in config.keys() and \
            config["scheduling"]:
        if config["scheduling"].lower() == "plateau":
            # learning rate scheduler
            scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                          mode=scheduler_mode,
                                          verbose=False,
                                          threshold_mode='abs',
                                          threshold=1e-8,
                                          factor=config.get(
                                              "decrease_factor", 0.1),
                                          patience=config.get("patience", 10))
            # scheduler step is executed after every validation
            scheduler_step_at = "validation"
        elif config["scheduling"].lower() == "decaying":
            scheduler = StepLR(optimizer=optimizer,
                               step_size=config.get("decaying_step_size", 1))
            # scheduler step is executed after every epoch
            scheduler_step_at = "epoch"
        elif config["scheduling"].lower() == "exponential":
            scheduler = ExponentialLR(optimizer=optimizer,
                                      gamma=config.get("decrease_factor",
                                                       0.99))
            # scheduler step is executed after every epoch
            scheduler_step_at = "epoch"
        elif config["scheduling"].lower() == "noam":
            factor = config.get("learning_rate_factor", 1)
            warmup = config.get("learning_rate_warmup", 4000)
            scheduler = NoamScheduler(hidden_size=hidden_size,
                                      factor=factor,
                                      warmup=warmup,
                                      optimizer=optimizer)

            scheduler_step_at = "step"
    return scheduler, scheduler_step_at
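
When the "exponential" option is chosen, build_scheduler returns an ExponentialLR paired with scheduler_step_at == "epoch", so the caller is expected to step it once per epoch. A small sketch of how such a pair would typically be consumed; the dummy model, optimizer and gamma value are assumptions.

from torch import nn, optim
from torch.optim.lr_scheduler import ExponentialLR

model = nn.Linear(8, 8)  # dummy model, illustration only
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# what the "exponential" branch above would produce
scheduler = ExponentialLR(optimizer=optimizer, gamma=0.99)
scheduler_step_at = "epoch"

for epoch in range(5):
    # ... train one epoch ...
    if scheduler_step_at == "epoch":
        scheduler.step()
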
Example #19
0
            model = torch.load(model_params['model_path'],
                               lambda storage, loc: storage)

    # Pass only the trainable parameters to the optimizer, otherwise PyTorch throws an error.
    # This is relevant to transfer learning with frozen (fixed) feature layers.

    optimizer = train_control['optimizer'](filter(lambda p: p.requires_grad,
                                                  model.parameters()),
                                           **optimizer_params)

    # Initiate Scheduler

    if (train_control['lr_scheduler_type'] == 'step'):
        scheduler = StepLR(optimizer, **train_control['step_scheduler_args'])
    elif (train_control['lr_scheduler_type'] == 'exp'):
        scheduler = ExponentialLR(optimizer,
                                  **train_control['exp_scheduler_args'])
    elif (train_control['lr_scheduler_type'] == 'plateau'):
        scheduler = ReduceLROnPlateau(
            optimizer, **train_control['plateau_scheduler_args'])
    else:
        scheduler = StepLR(optimizer, step_size=100, gamma=1)
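        # StepLR with gamma=1 never changes the learning rate, so this default
        # branch effectively keeps the learning rate constant.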

    if model_params['pytorch_device'] == 'gpu':
        with torch.cuda.device(model_params['cuda_device']):
            model_trainer = ModelTrainer(model,
                                         train_dataset_loader,
                                         valid_dataset_loader,
                                         test_dataset_loader,
                                         model_params['model_path'],
                                         optimizer=optimizer,
                                         optimizer_args=optimizer_params,
Example #20
0
    def select_opt_schr(self):

        self.g_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.G.parameters()), self.g_lr,
            (self.beta1, self.beta2))
        self.ds_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D_s.parameters()),
            self.d_lr, (self.beta1, self.beta2))
        self.dt_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.D_t.parameters()),
            self.d_lr, (self.beta1, self.beta2))
        if self.lr_schr == 'const':
            self.g_lr_scher = StepLR(self.g_optimizer,
                                     step_size=10000,
                                     gamma=1)
            self.ds_lr_scher = StepLR(self.ds_optimizer,
                                      step_size=10000,
                                      gamma=1)
            self.dt_lr_scher = StepLR(self.dt_optimizer,
                                      step_size=10000,
                                      gamma=1)
        elif self.lr_schr == 'step':
            self.g_lr_scher = StepLR(self.g_optimizer,
                                     step_size=500,
                                     gamma=0.98)
            self.ds_lr_scher = StepLR(self.ds_optimizer,
                                      step_size=500,
                                      gamma=0.98)
            self.dt_lr_scher = StepLR(self.dt_optimizer,
                                      step_size=500,
                                      gamma=0.98)
        elif self.lr_schr == 'exp':
            self.g_lr_scher = ExponentialLR(self.g_optimizer, gamma=0.9999)
            self.ds_lr_scher = ExponentialLR(self.ds_optimizer, gamma=0.9999)
            self.dt_lr_scher = ExponentialLR(self.dt_optimizer, gamma=0.9999)
        elif self.lr_schr == 'multi':
            self.g_lr_scher = MultiStepLR(self.g_optimizer, [10000, 30000],
                                          gamma=0.3)
            self.ds_lr_scher = MultiStepLR(self.ds_optimizer, [10000, 30000],
                                           gamma=0.3)
            self.dt_lr_scher = MultiStepLR(self.dt_optimizer, [10000, 30000],
                                           gamma=0.3)
        else:
            self.g_lr_scher = ReduceLROnPlateau(self.g_optimizer,
                                                mode='min',
                                                factor=self.lr_decay,
                                                patience=100,
                                                threshold=0.0001,
                                                threshold_mode='rel',
                                                cooldown=0,
                                                min_lr=1e-10,
                                                eps=1e-08,
                                                verbose=True)
            self.ds_lr_scher = ReduceLROnPlateau(self.ds_optimizer,
                                                 mode='min',
                                                 factor=self.lr_decay,
                                                 patience=100,
                                                 threshold=0.0001,
                                                 threshold_mode='rel',
                                                 cooldown=0,
                                                 min_lr=1e-10,
                                                 eps=1e-08,
                                                 verbose=True)
            self.dt_lr_scher = ReduceLROnPlateau(self.dt_optimizer,
                                                 mode='min',
                                                 factor=self.lr_decay,
                                                 patience=100,
                                                 threshold=0.0001,
                                                 threshold_mode='rel',
                                                 cooldown=0,
                                                 min_lr=1e-10,
                                                 eps=1e-08,
                                                 verbose=True)
Example #21
0
                    num_layers=NUM_LAYERS).to(device)
    model.load_state_dict(torch.load(MODEL))
    for params in model.parameters():
        params.requires_grad = True
if MODE == 4:
    model = Network(num_features=NUM_FEATURES,
                    num_layers=NUM_LAYERS).to(device)
    for params in model.parameters():
        params.requires_grad = True

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                    model.parameters()),
                             lr=LEARNING_RATE)
scheduler = ExponentialLR(optimizer, LAMBDA)
epoch, estop, maxval, maxind = 0, False, 0, 0

while epoch < NUM_EPOCHS and not estop:
    dataloader = DataLoader(dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=True,
                            drop_last=False)
    if WAIT and epoch > 4:
        scheduler.step()
    if not WAIT:
        scheduler.step()
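    # With WAIT set, the exponential decay only starts after epoch 4; otherwise
    # the learning rate is decayed once per epoch from the start of training.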

    for i_batch, batch in enumerate(dataloader):
        # Forward
        optimizer.zero_grad()
Example #22
0
def train(data_path, neg_batch_size, batch_size, shuffle, num_workers, nb_epochs, embedding_dim, hidden_dim, relation_dim, gpu, use_cuda,patience, freeze, validate_every, hops, lr, entdrop, reldrop, scoredrop, l3_reg, model_name, decay, ls, load_from, outfile, do_batch_norm, valid_data_path=None):
    print('Loading entities and relations')
    kg_type = 'full'
    if 'half' in hops:
        kg_type = 'half'
    checkpoint_file = '../../pretrained_models/embeddings/ComplEx_fbwq_' + kg_type + '/checkpoint_best.pt'
    print('Loading kg embeddings from', checkpoint_file)
    kge_checkpoint = load_checkpoint(checkpoint_file)
    kge_model = KgeModel.create_from(kge_checkpoint)
    kge_model.eval()
    e = getEntityEmbeddings(kge_model, hops)

    print('Loaded entities and relations')

    entity2idx, idx2entity, embedding_matrix = prepare_embeddings(e)
    data = process_text_file(data_path, split=False)
    print('Train file processed, making dataloader')
    # word2ix,idx2word, max_len = get_vocab(data)
    # hops = str(num_hops)
    device = torch.device(gpu if use_cuda else "cpu")
    dataset = DatasetMetaQA(data, e, entity2idx)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    print('Creating model...')
    model = RelationExtractor(embedding_dim=embedding_dim, num_entities = len(idx2entity), relation_dim=relation_dim, pretrained_embeddings=embedding_matrix, freeze=freeze, device=device, entdrop = entdrop, reldrop = reldrop, scoredrop = scoredrop, l3_reg = l3_reg, model = model_name, ls = ls, do_batch_norm=do_batch_norm)
    print('Model created!')
    if load_from != '':
        # model.load_state_dict(torch.load("checkpoints/roberta_finetune/" + load_from + ".pt"))
        fname = "checkpoints/roberta_finetune/" + load_from + ".pt"
        model.load_state_dict(torch.load(fname, map_location=lambda storage, loc: storage))
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = ExponentialLR(optimizer, decay)
    optimizer.zero_grad()
    best_score = -float("inf")
    best_model = model.state_dict()
    no_update = 0
    # time.sleep(10)
    for epoch in range(nb_epochs):
        phases = []
        for i in range(validate_every):
            phases.append('train')
        phases.append('valid')
        for phase in phases:
            if phase == 'train':
                model.train()
                # model.apply(set_bn_eval)
                loader = tqdm(data_loader, total=len(data_loader), unit="batches")
                running_loss = 0
                for i_batch, a in enumerate(loader):
                    model.zero_grad()
                    question_tokenized = a[0].to(device)
                    attention_mask = a[1].to(device)
                    positive_head = a[2].to(device)
                    positive_tail = a[3].to(device)    
                    loss = model(question_tokenized=question_tokenized, attention_mask=attention_mask, p_head=positive_head, p_tail=positive_tail)
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item()
                    loader.set_postfix(Loss=running_loss/((i_batch+1)*batch_size), Epoch=epoch)
                    loader.set_description('{}/{}'.format(epoch, nb_epochs))
                    loader.update()
                
                scheduler.step()

            elif phase=='valid':
                model.eval()
                eps = 0.0001
                answers, score = validate_v2(model=model, data_path= valid_data_path, entity2idx=entity2idx, train_dataloader=dataset, device=device, model_name=model_name)
                if score > best_score + eps:
                    best_score = score
                    no_update = 0
                    best_model = model.state_dict()
                    print(hops + " hop Validation accuracy (no relation scoring) increased from previous epoch", score)
                    # writeToFile(answers, 'results_' + model_name + '_' + hops + '.txt')
                    # torch.save(best_model, "checkpoints/roberta_finetune/best_score_model.pt")
                    # torch.save(best_model, "checkpoints/roberta_finetune/" + outfile + ".pt")
                elif (score < best_score + eps) and (no_update < patience):
                    no_update += 1
                    print("Validation accuracy did not improve (%f vs best %f), %d more epochs to check" % (score, best_score, patience - no_update))
                elif no_update == patience:
                    print("Model has exceeded patience. Saving best model and exiting")
                    # torch.save(best_model, "checkpoints/roberta_finetune/best_score_model.pt")
                    # torch.save(best_model, "checkpoints/roberta_finetune/" + outfile + ".pt")
                    exit()
                if epoch == nb_epochs-1:
                    print("Final Epoch has reached. Stoping and saving model.")
                    # torch.save(best_model, "checkpoints/roberta_finetune/best_score_model.pt")
                    # torch.save(best_model, "checkpoints/roberta_finetune/" + outfile + ".pt")
                    exit()
Example #23
0
    def __call__(self, config):
        if not os.path.exists(config.file):
            os.mkdir(config.file)
        if config.preprocess or not os.path.exists(config.vocab):
            print("Preprocess the corpus")
            pos_train = Corpus.load(config.fptrain, [1, 4], config.pos)
            dep_train = Corpus.load(config.ftrain)
            pos_dev = Corpus.load(config.fpdev, [1, 4])
            dep_dev = Corpus.load(config.fdev)
            pos_test = Corpus.load(config.fptest, [1, 4])
            dep_test = Corpus.load(config.ftest)
            print("Create the vocab")
            vocab = Vocab.from_corpora(pos_train, dep_train, 2)
            vocab.read_embeddings(Embedding.load(config.fembed))
            print("Load the dataset")
            pos_trainset = TextDataset(vocab.numericalize(pos_train, False),
                                       config.buckets)
            dep_trainset = TextDataset(vocab.numericalize(dep_train),
                                       config.buckets)
            pos_devset = TextDataset(vocab.numericalize(pos_dev, False),
                                     config.buckets)
            dep_devset = TextDataset(vocab.numericalize(dep_dev),
                                     config.buckets)
            pos_testset = TextDataset(vocab.numericalize(pos_test, False),
                                      config.buckets)
            dep_testset = TextDataset(vocab.numericalize(dep_test),
                                      config.buckets)
            torch.save(vocab, config.vocab)
            torch.save(pos_trainset, os.path.join(config.file, 'pos_trainset'))
            torch.save(dep_trainset, os.path.join(config.file, 'dep_trainset'))
            torch.save(pos_devset, os.path.join(config.file, 'pos_devset'))
            torch.save(dep_devset, os.path.join(config.file, 'dep_devset'))
            torch.save(pos_testset, os.path.join(config.file, 'pos_testset'))
            torch.save(dep_testset, os.path.join(config.file, 'dep_testset'))
        else:
            print("Load the vocab")
            vocab = torch.load(config.vocab)
            print("Load the datasets")
            pos_trainset = torch.load(os.path.join(config.file,
                                                   'pos_trainset'))
            dep_trainset = torch.load(os.path.join(config.file,
                                                   'dep_trainset'))
            pos_devset = torch.load(os.path.join(config.file, 'pos_devset'))
            dep_devset = torch.load(os.path.join(config.file, 'dep_devset'))
            pos_testset = torch.load(os.path.join(config.file, 'pos_testset'))
            dep_testset = torch.load(os.path.join(config.file, 'dep_testset'))
        config.update({
            'n_words': vocab.n_init,
            'n_chars': vocab.n_chars,
            'n_pos_tags': vocab.n_pos_tags,
            'n_dep_tags': vocab.n_dep_tags,
            'n_rels': vocab.n_rels,
            'pad_index': vocab.pad_index,
            'unk_index': vocab.unk_index
        })
        # set the data loaders
        pos_train_loader = batchify(
            pos_trainset, config.pos_batch_size // config.update_steps, True)
        dep_train_loader = batchify(dep_trainset,
                                    config.batch_size // config.update_steps,
                                    True)
        pos_dev_loader = batchify(pos_devset, config.pos_batch_size)
        dep_dev_loader = batchify(dep_devset, config.batch_size)
        pos_test_loader = batchify(pos_testset, config.pos_batch_size)
        dep_test_loader = batchify(dep_testset, config.batch_size)

        print(vocab)
        print(f"{'pos_train:':10} {len(pos_trainset):7} sentences in total, "
              f"{len(pos_train_loader):4} batches provided")
        print(f"{'dep_train:':10} {len(dep_trainset):7} sentences in total, "
              f"{len(dep_train_loader):4} batches provided")
        print(f"{'pos_dev:':10} {len(pos_devset):7} sentences in total, "
              f"{len(pos_dev_loader):4} batches provided")
        print(f"{'dep_dev:':10} {len(dep_devset):7} sentences in total, "
              f"{len(dep_dev_loader):4} batches provided")
        print(f"{'pos_test:':10} {len(pos_testset):7} sentences in total, "
              f"{len(pos_test_loader):4} batches provided")
        print(f"{'dep_test:':10} {len(dep_testset):7} sentences in total, "
              f"{len(dep_test_loader):4} batches provided")

        print("Create the model")
        parser = BiaffineParser(config, vocab.embed).to(config.device)
        print(f"{parser}\n")

        model = Model(config, vocab, parser)

        total_time = timedelta()
        best_e, best_metric = 1, AttachmentMethod()
        model.optimizer = Adam(model.parser.parameters(), config.lr,
                               (config.mu, config.nu), config.epsilon)
        model.scheduler = ExponentialLR(model.optimizer,
                                        config.decay**(1 / config.decay_steps))
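        # gamma = decay ** (1 / decay_steps): after decay_steps scheduler steps,
        # the learning rate has been multiplied by decay.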

        for epoch in range(1, config.epochs + 1):
            start = datetime.now()
            # train one epoch and update the parameters
            model.train(pos_train_loader, dep_train_loader)
            print(f"Epoch {epoch} / {config.epochs}:")
            lp, ld, mp, mdt, mdp = model.evaluate(None, dep_train_loader)
            print(f"{'train:':6} LP: {lp:.4f} LD: {ld:.4f} {mp} {mdt} {mdp}")
            lp, ld, mp, mdt, dev_m = model.evaluate(pos_dev_loader,
                                                    dep_dev_loader)
            print(f"{'dev:':6} LP: {lp:.4f} LD: {ld:.4f} {mp} {mdt} {dev_m}")
            lp, ld, mp, mdt, mdp = model.evaluate(pos_test_loader,
                                                  dep_test_loader)
            print(f"{'test:':6} LP: {lp:.4f} LD: {ld:.4f} {mp} {mdt} {mdp}")

            t = datetime.now() - start
            # save the model if it is the best so far
            if dev_m > best_metric and epoch > config.patience:
                best_e, best_metric = epoch, dev_m
                model.parser.save(config.model)
                print(f"{t}s elapsed (saved)\n")
            else:
                print(f"{t}s elapsed\n")
            total_time += t
            if epoch - best_e >= config.patience:
                break
        model.parser = BiaffineParser.load(config.model)
        lp, ld, mp, mdt, mdp = model.evaluate(pos_test_loader, dep_test_loader)

        print(f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
        print(f"the score of test at epoch {best_e} is {mdp.score:.2%}")
        print(f"average time of each epoch is {total_time / epoch}s")
        print(f"{total_time}s elapsed")
Example #24
0
    def train(self,
              train,
              dev,
              test,
              buckets=32,
              batch_size=5000,
              lr=2e-3,
              mu=.9,
              nu=.9,
              epsilon=1e-12,
              clip=5.0,
              decay=.75,
              decay_steps=5000,
              epochs=5000,
              patience=100,
              verbose=True,
              **kwargs):
        args = self.args.update(locals())
        init_logger(logger, verbose=args.verbose)

        self.transform.train()
        if dist.is_initialized():
            args.batch_size = args.batch_size // dist.get_world_size()
        logger.info("Loading the data")
        train = Dataset(self.transform, args.train, **args)
        dev = Dataset(self.transform, args.dev)
        test = Dataset(self.transform, args.test)
        logger.info("Building the datasets")
        train.build(args.batch_size, args.buckets, True, dist.is_initialized())
        logger.info("train built")
        dev.build(args.batch_size, args.buckets)
        logger.info("dev built")
        test.build(args.batch_size, args.buckets)
        logger.info(
            f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")

        logger.info(f"{self.model}\n")
        if dist.is_initialized():
            self.model = DDP(self.model,
                             device_ids=[args.local_rank],
                             find_unused_parameters=True)
        self.optimizer = Adam(self.model.parameters(), args.lr,
                              (args.mu, args.nu), args.epsilon)
        self.scheduler = ExponentialLR(self.optimizer,
                                       args.decay**(1 / args.decay_steps))

        elapsed = timedelta()
        best_e, best_metric = 1, Metric()

        for epoch in range(1, args.epochs + 1):
            start = datetime.now()

            logger.info(f"Epoch {epoch} / {args.epochs}:")
            self._train(train.loader)
            loss, dev_metric = self._evaluate(dev.loader)
            logger.info(f"{'dev:':6} - loss: {loss:.4f} - {dev_metric}")
            loss, test_metric = self._evaluate(test.loader)
            logger.info(f"{'test:':6} - loss: {loss:.4f} - {test_metric}")

            t = datetime.now() - start
            # save the model if it is the best so far
            if dev_metric > best_metric:
                best_e, best_metric = epoch, dev_metric
                if is_master():
                    self.save(args.path)
                logger.info(f"{t}s elapsed (saved)\n")
            else:
                logger.info(f"{t}s elapsed\n")
            elapsed += t
            if epoch - best_e >= args.patience:
                break
        loss, metric = self.load(**args)._evaluate(test.loader)

        logger.info(f"Epoch {best_e} saved")
        logger.info(f"{'dev:':6} - {best_metric}")
        logger.info(f"{'test:':6} - {metric}")
        logger.info(f"{elapsed}s elapsed, {elapsed / epoch}s/epoch")
Example #25
0
def test_warmup_against_original_schedule(lr_scheduler_type: LRSchedulerType,
                                          warmup_epochs: int) -> None:
    """
    Tests if LR scheduler with warmup matches the Pytorch implementation after the warmup stage is completed.
    """
    config = DummyModel(num_epochs=6,
                        l_rate=1e-2,
                        l_rate_scheduler=lr_scheduler_type,
                        l_rate_exponential_gamma=0.9,
                        l_rate_step_gamma=0.9,
                        l_rate_step_step_size=2,
                        l_rate_multi_step_gamma=0.9,
                        l_rate_multi_step_milestones=[3, 5, 7],
                        l_rate_polynomial_gamma=0.9,
                        l_rate_warmup=LRWarmUpType.Linear
                        if warmup_epochs > 0 else LRWarmUpType.NoWarmUp,
                        l_rate_warmup_epochs=warmup_epochs)
    # create lr scheduler
    lr_scheduler, optimizer1 = _create_lr_scheduler_and_optimizer(config)

    original_scheduler: Optional[_LRScheduler] = None
    optimizer2 = _create_dummy_optimizer(config)
    # This mimics the code in SchedulerWithWarmUp.get_scheduler and must be in sync
    if lr_scheduler_type == LRSchedulerType.Exponential:
        original_scheduler = ExponentialLR(
            optimizer=optimizer2, gamma=config.l_rate_exponential_gamma)
    elif lr_scheduler_type == LRSchedulerType.Step:
        original_scheduler = StepLR(optimizer=optimizer2,
                                    step_size=config.l_rate_step_step_size,
                                    gamma=config.l_rate_step_gamma)
    elif lr_scheduler_type == LRSchedulerType.Cosine:
        original_scheduler = CosineAnnealingLR(optimizer2,
                                               T_max=config.num_epochs,
                                               eta_min=config.min_l_rate)
    elif lr_scheduler_type == LRSchedulerType.MultiStep:
        assert config.l_rate_multi_step_milestones is not None  # for mypy
        original_scheduler = MultiStepLR(
            optimizer=optimizer2,
            milestones=config.l_rate_multi_step_milestones,
            gamma=config.l_rate_multi_step_gamma)
    elif lr_scheduler_type == LRSchedulerType.Polynomial:
        x = config.min_l_rate / config.l_rate
        polynomial_decay: Any = lambda epoch: (1 - x) * ((1. - float(
            epoch) / config.num_epochs)**config.l_rate_polynomial_gamma) + x
        original_scheduler = LambdaLR(optimizer=optimizer2,
                                      lr_lambda=polynomial_decay)
    else:
        raise ValueError("Scheduler has not been added to this test.")

    expected_lr_list = []
    if warmup_epochs == 0:
        pass
    elif warmup_epochs == 3:
        # For the first config.l_rate_warmup_epochs, the learning rate is lower than the initial learning rate by a
        # linear factor
        expected_lr_list.extend([f * config.l_rate for f in [0.25, 0.5, 0.75]])
    else:
        raise NotImplementedError()
    expected_lr_list.extend(
        enumerate_scheduler(original_scheduler,
                            config.num_epochs - warmup_epochs))
    print(f"Expected schedule with warmup: {expected_lr_list}")

    lr_with_warmup_scheduler = enumerate_scheduler(lr_scheduler,
                                                   config.num_epochs)
    print(f"Actual schedule: {lr_with_warmup_scheduler}")

    if ((lr_scheduler_type == LRSchedulerType.Polynomial
         or lr_scheduler_type == LRSchedulerType.Cosine)
            and warmup_epochs > 0):
        # Polynomial and Cosine scheduler will be squashed in time because the number of epochs is reduced
        # (both schedulers take a "length of training" argument, and that is now shorter). Skip comparing those.
        pass
    else:
        assert np.allclose(lr_with_warmup_scheduler,
                           expected_lr_list,
                           rtol=1e-5)
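
The test above verifies that a scheduler with linear warmup reproduces the plain PyTorch schedule once the warmup phase is over. One way to sketch the same shape (linear ramp, then exponential decay) is a single LambdaLR; this is only an illustration, not the SchedulerWithWarmUp implementation exercised by the test, and the warmup length and gamma are assumed values.

from torch import nn, optim
from torch.optim.lr_scheduler import LambdaLR

warmup_epochs, gamma = 3, 0.9  # assumed values, mirroring the test's parameters

def warmup_then_exp(epoch: int) -> float:
    # linear ramp 0.25, 0.5, 0.75 during warmup, then gamma ** (epoch - warmup)
    if epoch < warmup_epochs:
        return (epoch + 1) / (warmup_epochs + 1)
    return gamma ** (epoch - warmup_epochs)

model = nn.Linear(4, 4)
optimizer = optim.SGD(model.parameters(), lr=1e-2)
scheduler = LambdaLR(optimizer, lr_lambda=warmup_then_exp)

for epoch in range(6):
    optimizer.step()  # PyTorch warns if scheduler.step() precedes optimizer.step()
    scheduler.step()
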
Example #26
0
def main():
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    args.rank = -1
    args.world_size = 1

    if args.model_parallel:
        args.deepspeed = False
        cfg = {
            "microbatches": args.num_microbatches,
            "placement_strategy": args.placement_strategy,
            "pipeline": args.pipeline,
            "optimize": args.optimize,
            "partitions": args.num_partitions,
            "horovod": args.horovod,
            "ddp": args.ddp,
        }

        smp.init(cfg)
        torch.cuda.set_device(smp.local_rank())
        args.rank = smp.dp_rank()
        args.world_size = smp.size()
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
        if deepspeed_utils.is_root_worker():
            args.rank = 0

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed + args.rank)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # args.LEARNING_RATE = args.LEARNING_RATE * float(args.world_size)

    cudnn.deterministic = True

    if cudnn.deterministic:
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.kwargs = {'num_workers': args.num_worker, 'pin_memory': True}

    device = torch.device("cuda")

    logger.debug(f"args.image_folder : {args.image_folder}")
    logger.debug(f"args.rank : {args.rank}")

    ## SageMaker
    try:
        if os.environ.get('SM_MODEL_DIR') is not None:
            args.model_dir = os.environ.get('SM_MODEL_DIR')
            #             args.output_dir = os.environ.get('SM_OUTPUT_DATA_DIR')
            args.image_folder = os.environ.get('SM_CHANNEL_TRAINING')
    except:
        logger.debug("not SageMaker")
        pass

    IMAGE_SIZE = args.image_size
    IMAGE_PATH = args.image_folder

    EPOCHS = args.EPOCHS
    BATCH_SIZE = args.BATCH_SIZE
    LEARNING_RATE = args.LEARNING_RATE
    LR_DECAY_RATE = args.LR_DECAY_RATE

    NUM_TOKENS = args.NUM_TOKENS
    NUM_LAYERS = args.NUM_LAYERS
    NUM_RESNET_BLOCKS = args.NUM_RESNET_BLOCKS
    SMOOTH_L1_LOSS = args.SMOOTH_L1_LOSS
    EMB_DIM = args.EMB_DIM
    HID_DIM = args.HID_DIM
    KL_LOSS_WEIGHT = args.KL_LOSS_WEIGHT

    STARTING_TEMP = args.STARTING_TEMP
    TEMP_MIN = args.TEMP_MIN
    ANNEAL_RATE = args.ANNEAL_RATE

    NUM_IMAGES_SAVE = args.NUM_IMAGES_SAVE

    #     transform = Compose(
    #         [
    #             RandomResizedCrop(args.image_size, args.image_size),
    #             OneOf(
    #                 [
    #                     IAAAdditiveGaussianNoise(),
    #                     GaussNoise(),
    #                 ],
    #                 p=0.2
    #             ),
    #             VerticalFlip(p=0.5),
    #             OneOf(
    #                 [
    #                     MotionBlur(p=.2),
    #                     MedianBlur(blur_limit=3, p=0.1),
    #                     Blur(blur_limit=3, p=0.1),
    #                 ],
    #                 p=0.2
    #             ),
    #             OneOf(
    #                 [
    #                     CLAHE(clip_limit=2),
    #                     IAASharpen(),
    #                     IAAEmboss(),
    #                     RandomBrightnessContrast(),
    #                 ],
    #                 p=0.3
    #             ),
    #             HueSaturationValue(p=0.3),
    # #             Normalize(
    # #                 mean=[0.485, 0.456, 0.406],
    # #                 std=[0.229, 0.224, 0.225],
    # #             )
    #         ],
    #         p=1.0
    #     )

    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])

    sampler = None
    dl = None

    # data
    logger.debug(f"IMAGE_PATH : {IMAGE_PATH}")
    #     ds = AlbumentationImageDataset(
    #         IMAGE_PATH,
    #         transform=transform,
    #         args=args
    #     )
    ds = ImageFolder(
        IMAGE_PATH,
        transform=transform,
    )

    if args.model_parallel and (args.ddp
                                or args.horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        ds = SplitDataset(ds, partitions=partitions_dict)
        ds.select(f"{smp.dp_rank()}")

    dl = DataLoader(ds,
                    BATCH_SIZE,
                    shuffle=True,
                    drop_last=args.model_parallel,
                    **args.kwargs)

    vae_params = dict(image_size=IMAGE_SIZE,
                      num_layers=NUM_LAYERS,
                      num_tokens=NUM_TOKENS,
                      codebook_dim=EMB_DIM,
                      hidden_dim=HID_DIM,
                      num_resnet_blocks=NUM_RESNET_BLOCKS)

    vae = DiscreteVAE(**vae_params,
                      smooth_l1_loss=SMOOTH_L1_LOSS,
                      kl_div_loss_weight=KL_LOSS_WEIGHT).to(device)
    # optimizer

    opt = Adam(vae.parameters(), lr=LEARNING_RATE)
    sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

    if args.model_parallel:
        import copy
        dummy_codebook = copy.deepcopy(vae.codebook)
        dummy_decoder = copy.deepcopy(vae.decoder)

        vae = smp.DistributedModel(vae)
        scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)

        if args.partial_checkpoint:
            args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
        elif args.full_checkpoint:
            args.checkpoint = smp.load(args.full_checkpoint, partial=False)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])

    assert len(ds) > 0, 'folder does not contain any images'

    if (not args.model_parallel) and args.rank == 0:
        print(f'{len(ds)} images found for training')

        # weights & biases experiment tracking

        #         import wandb

        model_config = dict(num_tokens=NUM_TOKENS,
                            smooth_l1_loss=SMOOTH_L1_LOSS,
                            num_resnet_blocks=NUM_RESNET_BLOCKS,
                            kl_loss_weight=KL_LOSS_WEIGHT)

#         run = wandb.init(
#             project = 'dalle_train_vae',
#             job_type = 'train_model',
#             config = model_config
#         )

    def save_model(path):
        if not args.rank == 0:
            return

        save_obj = {'hparams': vae_params, 'weights': vae.state_dict()}

        torch.save(save_obj, path)

    # distribute with deepspeed
    if not args.model_parallel:
        deepspeed_utils.check_batch_size(BATCH_SIZE)
        deepspeed_config = {'train_batch_size': BATCH_SIZE}

        (distr_vae, opt, dl, sched) = deepspeed_utils.maybe_distribute(
            args=args,
            model=vae,
            optimizer=opt,
            model_parameters=vae.parameters(),
            training_data=ds if args.deepspeed else dl,
            lr_scheduler=sched,
            config_params=deepspeed_config,
        )

    try:
        # Rubik: Define smp.step. Return any tensors needed outside.
        @smp.step
        def train_step(vae, images, temp):
            #             logger.debug(f"args.amp : {args.amp}")
            with autocast(enabled=(args.amp > 0)):
                loss, recons = vae(images,
                                   return_loss=True,
                                   return_recons=True,
                                   temp=temp)

            scaled_loss = scaler.scale(loss) if args.amp else loss
            vae.backward(scaled_loss)
            #             torch.nn.utils.clip_grad_norm_(vae.parameters(), 5)
            return loss, recons

        @smp.step
        def get_codes_step(vae, images, k):
            images = images[:k]
            logits = vae.forward(images, return_logits=True)
            codebook_indices = logits.argmax(dim=1).flatten(1)
            return codebook_indices

        def hard_recons_step(dummy_decoder, dummy_codebook, codebook_indices):
            from functools import partial
            for module in dummy_codebook.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            image_embeds = dummy_codebook.forward(codebook_indices)
            b, n, d = image_embeds.shape
            h = w = int(sqrt(n))

            image_embeds = rearrange(image_embeds,
                                     'b (h w) d -> b d h w',
                                     h=h,
                                     w=w)
            for module in dummy_decoder.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            hard_recons = dummy_decoder.forward(image_embeds)
            return hard_recons

    except:
        pass

    # starting temperature

    global_step = 0
    temp = STARTING_TEMP

    for epoch in range(EPOCHS):
        ##
        batch_time = util.AverageMeter('Time', ':6.3f')
        data_time = util.AverageMeter('Data', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')
        top1 = util.AverageMeter('Acc@1', ':6.2f')
        top5 = util.AverageMeter('Acc@5', ':6.2f')
        progress = util.ProgressMeter(
            len(dl), [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))

        vae.train()
        start = time.time()

        for i, (images, _) in enumerate(dl):
            images = images.to(device, non_blocking=True)
            opt.zero_grad()

            if args.model_parallel:
                loss, recons = train_step(vae, images, temp)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()
                recons = recons.reduce_mean()
            else:
                loss, recons = distr_vae(images,
                                         return_loss=True,
                                         return_recons=True,
                                         temp=temp)

            if (not args.model_parallel) and args.deepspeed:
                # Gradients are automatically zeroed after the step
                distr_vae.backward(loss)
                distr_vae.step()
            elif args.model_parallel:
                if args.amp:
                    scaler.step(opt)
                    scaler.update()
                else:
                    # some optimizers (e.g. Adadelta in PyTorch 1.8) don't like optimizer.step() being called with no parameters
                    if len(list(vae.local_parameters())) > 0:
                        opt.step()
            else:
                loss.backward()
                opt.step()

            logs = {}

            if i % 10 == 0:
                if args.rank == 0:
                    #                 if deepspeed_utils.is_root_worker():
                    k = NUM_IMAGES_SAVE

                    with torch.no_grad():
                        if args.model_parallel:
                            model_dict = vae.state_dict()
                            model_dict_updated = {}
                            for key, val in model_dict.items():
                                if "decoder" in key:
                                    key = key.replace("decoder.", "")
                                elif "codebook" in key:
                                    key = key.replace("codebook.", "")
                                model_dict_updated[key] = val

                            dummy_decoder.load_state_dict(model_dict_updated,
                                                          strict=False)
                            dummy_codebook.load_state_dict(model_dict_updated,
                                                           strict=False)
                            codes = get_codes_step(vae, images, k)
                            codes = codes.reduce_mean().to(torch.long)
                            hard_recons = hard_recons_step(
                                dummy_decoder, dummy_codebook, codes)
                        else:
                            codes = vae.get_codebook_indices(images[:k])
                            hard_recons = vae.decode(codes)

                    images, recons = map(lambda t: t[:k], (images, recons))
                    images, recons, hard_recons, codes = map(
                        lambda t: t.detach().cpu(),
                        (images, recons, hard_recons, codes))
                    images, recons, hard_recons = map(
                        lambda t: make_grid(t.float(),
                                            nrow=int(sqrt(k)),
                                            normalize=True,
                                            range=(-1, 1)),
                        (images, recons, hard_recons))

#                     logs = {
#                         **logs,
#                         'sample images':        wandb.Image(images, caption = 'original images'),
#                         'reconstructions':      wandb.Image(recons, caption = 'reconstructions'),
#                         'hard reconstructions': wandb.Image(hard_recons, caption = 'hard reconstructions'),
#                         'codebook_indices':     wandb.Histogram(codes),
#                         'temperature':          temp
#                     }

                if args.model_parallel:
                    filename = f'{args.model_dir}/vae.pt'
                    if smp.dp_rank() == 0:
                        if args.save_full_model:
                            model_dict = vae.state_dict()
                            opt_dict = opt.state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=False,
                            )
                        else:
                            model_dict = vae.local_state_dict()
                            opt_dict = opt.local_state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=True,
                            )
                    smp.barrier()

                else:
                    save_model(f'{args.model_dir}/vae.pt')
                    # wandb.save(f'{args.model_dir}/vae.pt')

                # temperature anneal
                temp = max(temp * math.exp(-ANNEAL_RATE * global_step),
                           TEMP_MIN)

                # lr decay
                sched.step()

            # Collective loss, averaged
            if args.model_parallel:
                avg_loss = loss.detach().clone()
                #                 print("args.world_size : {}".format(args.world_size))
                avg_loss /= args.world_size

            else:
                avg_loss = deepspeed_utils.average_all(loss)

            if args.rank == 0:
                if i % 100 == 0:
                    lr = sched.get_last_lr()[0]
                    print(epoch, i, f'lr - {lr:6f}, loss - {avg_loss.item()},')

                    logs = {
                        **logs, 'epoch': epoch,
                        'iter': i,
                        'loss': avg_loss.item(),
                        'lr': lr
                    }

#                 wandb.log(logs)
            global_step += 1

            if args.rank == 0:
                # Every print_freq iterations, check the loss, accuracy, and speed.
                # For best performance, it doesn't make sense to print these metrics every
                # iteration, since they incur an allreduce and some host<->device syncs.

                # Measure accuracy
                #                 prec1, prec5 = util.accuracy(output, target, topk=(1, 5))

                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), images.size(0))
                #                 top1.update(util.to_python_float(prec1), images.size(0))
                #                 top5.update(util.to_python_float(prec5), images.size(0))

                # wait for queued GPU operations to finish (PyTorch launches kernels asynchronously by default)
                torch.cuda.synchronize()
                batch_time.update((time.time() - start) / args.log_interval)
                start = time.time()  # reset the timer for the next measurement window

                print(
                    'Epoch: [{0}][{1}/{2}] '
                    'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                    'Train_Speed={3:.3f} ({4:.3f}), '
                    'Train_Loss={loss.val:.10f}:({loss.avg:.4f}),'.format(
                        epoch,
                        i,
                        len(dl),
                        args.world_size * BATCH_SIZE / batch_time.val,
                        args.world_size * BATCH_SIZE / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses))

#         if deepspeed_utils.is_root_worker():
# save trained model to wandb as an artifact every epoch's end

#             model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
#             model_artifact.add_file(f'{args.model_dir}/vae.pt')
#             run.log_artifact(model_artifact)

    if args.rank == 0:
        #     if deepspeed_utils.is_root_worker():
        # save final vae and cleanup
        if args.model_parallel:
            logger.debug('save model_parallel')
        else:
            save_model(os.path.join(args.model_dir, 'vae-final.pt'))


#         wandb.save(f'{args.model_dir}/vae-final.pt')

#         model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
#         model_artifact.add_file(f'{args.model_dir}/vae-final.pt')
#         run.log_artifact(model_artifact)

#         wandb.finish()

    if args.model_parallel:
        if args.assert_losses:
            if args.horovod or args.ddp:
                # SM Distributed: If using data parallelism, gather all losses across different model
                # replicas and check if losses match.

                losses = smp.allgather(loss, smp.DP_GROUP)
                for l in losses:
                    print(l)
                    assert math.isclose(l, losses[0])

                assert loss < 0.18
            else:
                assert loss < 0.08

        smp.barrier()
        print("SMP training finished successfully")
Example #27
0
class TrackerSiamFC(Tracker):
    def __init__(self, net_path=None, **kwargs):
        super(TrackerSiamFC, self).__init__('SiamFC', True)
        self.cfg = self.parse_args(**kwargs)

        # setup GPU device if available
        self.cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.cuda else 'cpu')

        # setup model
        self.net = Net(backbone=AlexNetV1(), head=SiamFC(self.cfg.out_scale))
        ops.init_weights(self.net)

        # load checkpoint if provided
        if net_path is not None:
            self.net.load_state_dict(
                torch.load(net_path,
                           map_location=lambda storage, loc: storage))
        self.net = self.net.to(self.device)

        # setup criterion
        self.criterion = BalancedLoss()
        # print("loss function:")
        # print(self.criterion)

        # setup optimizer
        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=self.cfg.initial_lr,
                                   weight_decay=self.cfg.weight_decay,
                                   momentum=self.cfg.momentum)

        # setup lr scheduler
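        # gamma is chosen so that initial_lr * gamma ** epoch_num == ultimate_lr,
        # i.e. the learning rate decays smoothly from initial_lr to ultimate_lr over training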
        gamma = np.power(self.cfg.ultimate_lr / self.cfg.initial_lr,
                         1.0 / self.cfg.epoch_num)
        self.lr_scheduler = ExponentialLR(self.optimizer, gamma)

    def parse_args(self, **kwargs):
        # default parameters
        cfg = {
            # basic parameters
            'out_scale': 0.001,
            'exemplar_sz': 127,
            'instance_sz': 255,
            'context': 0.5,
            # inference parameters
            'scale_num': 3,
            'scale_step': 1.0375,
            'scale_lr': 0.59,
            'scale_penalty': 0.9745,
            'window_influence': 0.176,
            'response_sz': 17,
            'response_up': 16,
            'total_stride': 8,
            # train parameters
            'epoch_num': 50,
            'batch_size': 8,
            'num_workers': 32,
            'initial_lr': 1e-2,
            'ultimate_lr': 1e-5,
            'weight_decay': 5e-4,
            'momentum': 0.9,
            'r_pos': 16,
            'r_neg': 0
        }

        for key, val in kwargs.items():
            if key in cfg:
                cfg.update({key: val})
        return namedtuple('Config', cfg.keys())(**cfg)

    @torch.no_grad()
    def init(self, img, box):
        # set to evaluation mode
        self.net.eval()

        # convert box to 0-indexed and center based [y, x, h, w]
        box = np.array([
            box[1] - 1 + (box[3] - 1) / 2, box[0] - 1 +
            (box[2] - 1) / 2, box[3], box[2]
        ],
                       dtype=np.float32)
        self.center, self.target_sz = box[:2], box[2:]

        # create hanning window
        self.upscale_sz = self.cfg.response_up * self.cfg.response_sz
        self.hann_window = np.outer(np.hanning(self.upscale_sz),
                                    np.hanning(self.upscale_sz))
        self.hann_window /= self.hann_window.sum()

        # search scale factors
        self.scale_factors = self.cfg.scale_step**np.linspace(
            -(self.cfg.scale_num // 2), self.cfg.scale_num // 2,
            self.cfg.scale_num)

        # exemplar and search sizes
        context = self.cfg.context * np.sum(self.target_sz)
        self.z_sz = np.sqrt(np.prod(self.target_sz + context))
        self.x_sz = self.z_sz * \
                    self.cfg.instance_sz / self.cfg.exemplar_sz

        # exemplar image
        self.avg_color = np.mean(img, axis=(0, 1))
        z = ops.crop_and_resize(img,
                                self.center,
                                self.z_sz,
                                out_size=self.cfg.exemplar_sz,
                                border_value=self.avg_color)

        # exemplar features
        z = torch.from_numpy(z).to(self.device).permute(
            2, 0, 1).unsqueeze(0).float()
        self.kernel = self.net.backbone(z)

    @torch.no_grad()
    def update(self, img, f):
        # set to evaluation mode
        self.net.eval()
        x = [
            ops.crop_and_resize(img,
                                self.center,
                                self.x_sz * f,
                                out_size=self.cfg.instance_sz,
                                border_value=self.avg_color)
            for f in self.scale_factors
        ]
        x = np.stack(x, axis=0)
        x = torch.from_numpy(x).to(self.device).permute(0, 3, 1, 2).float()

        # responses
        x = self.net.backbone(x)
        responses = self.net.head(self.kernel, x)
        responses = responses.squeeze(1).cpu().numpy()

        # upsample responses and penalize scale changes
        responses = np.stack([
            cv2.resize(u, (self.upscale_sz, self.upscale_sz),
                       interpolation=cv2.INTER_CUBIC) for u in responses
        ])
        responses[:self.cfg.scale_num // 2] *= self.cfg.scale_penalty
        responses[self.cfg.scale_num // 2 + 1:] *= self.cfg.scale_penalty

        # peak scale
        scale_id = np.argmax(np.amax(responses, axis=(1, 2)))

        # peak location
        response = responses[scale_id]
        response -= response.min()
        response /= response.sum() + 1e-16
        response = (1 - self.cfg.window_influence) * response + \
                   self.cfg.window_influence * self.hann_window
        loc = np.unravel_index(response.argmax(), response.shape)
        # print(loc)

        # locate target center
        disp_in_response = np.array(loc) - (self.upscale_sz - 1) / 2
        disp_in_instance = disp_in_response * \
                           self.cfg.total_stride / self.cfg.response_up
        disp_in_image = disp_in_instance * self.x_sz * \
                        self.scale_factors[scale_id] / self.cfg.instance_sz

        self.center = self.center + disp_in_image

        # update target size
        scale = (1 - self.cfg.scale_lr) * 1.0 + \
                self.cfg.scale_lr * self.scale_factors[scale_id]
        self.target_sz *= scale
        self.z_sz *= scale
        self.x_sz *= scale

        x_siamFC = self.center[1] + 1 - (self.target_sz[1] - 1) / 2
        y_siamFC = self.center[0] + 1 - (self.target_sz[0] - 1) / 2

        box = np.array(
            [x_siamFC, y_siamFC, self.target_sz[1], self.target_sz[0]])
        return box

    def track(self, img_files, box, visualize=False):
        frame_num = len(img_files)
        boxes = np.zeros((frame_num, 4))
        boxes[0] = box
        times = np.zeros(frame_num)

        for f, img_file in enumerate(img_files):
            img = ops.read_image(img_file)

            begin = time.time()
            if f == 0:
                self.init(img, box)
            else:
                boxes[f, :] = self.update(img, f)
            times[f] = time.time() - begin
            # print(boxes[f, :])
            if visualize:
                ops.show_image(img, boxes[f, :])

        return boxes, times

    def train_step(self, batch, backward=True):
        # set network mode
        self.net.train(backward)

        # parse batch data
        z = batch[0].to(self.device, non_blocking=self.cuda)
        x = batch[1].to(self.device, non_blocking=self.cuda)

        with torch.set_grad_enabled(backward):
            # inference
            responses = self.net(z, x)

            # calculate loss
            labels = self._create_labels(responses.size())
            loss = self.criterion(responses, labels)

            if backward:
                # back propagation
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        return loss.item()

    @torch.enable_grad()
    def train_over(self, seqs, val_seqs=None, save_dir='defaultpretrained'):
        # set to train mode
        self.net.train()

        # create save_dir folder
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # setup dataset
        transforms = SiamFCTransforms(exemplar_sz=self.cfg.exemplar_sz,
                                      instance_sz=self.cfg.instance_sz,
                                      context=self.cfg.context)
        dataset = Pair(seqs=seqs, transforms=transforms)

        # setup dataloader
        dataloader = DataLoader(dataset,
                                batch_size=self.cfg.batch_size,
                                shuffle=True,
                                num_workers=self.cfg.num_workers,
                                pin_memory=self.cuda,
                                drop_last=True)

        # loop over epochs
        for epoch in range(self.cfg.epoch_num):
            # update lr at each epoch
            self.lr_scheduler.step(epoch=epoch)

            # loop over dataloader
            for it, batch in enumerate(dataloader):
                loss = self.train_step(batch, backward=True)
                print('Epoch: {} [{}/{}] Loss: {:.5f}'.format(
                    epoch + 1, it + 1, len(dataloader), loss))
                sys.stdout.flush()

            # save checkpoint
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            net_path = os.path.join(save_dir,
                                    'siamfc_alexnet_e%d.pth' % (epoch + 1))
            torch.save(self.net.state_dict(), net_path)

    def _create_labels(self, size):
        # skip if same sized labels already created
        if hasattr(self, 'labels') and self.labels.size() == size:
            return self.labels

        def logistic_labels(x, y, r_pos, r_neg):
            dist = np.abs(x) + np.abs(y)  # block distance
            labels = np.where(
                dist <= r_pos, np.ones_like(x),
                np.where(dist < r_neg,
                         np.ones_like(x) * 0.5, np.zeros_like(x)))
            return labels

        # distances along x- and y-axis
        n, c, h, w = size
        x = np.arange(w) - (w - 1) / 2
        y = np.arange(h) - (h - 1) / 2
        x, y = np.meshgrid(x, y)

        # create logistic labels
        r_pos = self.cfg.r_pos / self.cfg.total_stride
        r_neg = self.cfg.r_neg / self.cfg.total_stride
        labels = logistic_labels(x, y, r_pos, r_neg)

        # repeat to size
        labels = labels.reshape((1, 1, h, w))
        labels = np.tile(labels, (n, c, 1, 1))

        # convert to tensors
        self.labels = torch.from_numpy(labels).to(self.device).float()

        return self.labels
Example #28
0
 def test_exp_lr(self):
     epochs = 10
     single_targets = [0.05 * (0.9 ** x) for x in range(epochs)]
     targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
     scheduler = ExponentialLR(self.opt, gamma=0.9)
     self._test(scheduler, targets, epochs)
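
The targets above are just the closed form lr_t = lr_0 * gamma ** t. A minimal standalone sketch of the same behaviour (the optimizer and values here are assumptions, not the test fixture):

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR

param = torch.zeros(1, requires_grad=True)
opt = SGD([param], lr=0.05)
sched = ExponentialLR(opt, gamma=0.9)

for _ in range(3):
    opt.step()    # step the optimizer before the scheduler (PyTorch >= 1.1 convention)
    sched.step()  # multiplies the lr of every param group by gamma

print(opt.param_groups[0]['lr'])  # ~0.05 * 0.9 ** 3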
Example #29
0
    def train_and_eval(self):
        print("Training the model...")
        self.entity_idxs = {d.entities[i]: i for i in range(len(d.entities))}
        self.relation_idxs = {
            d.relations[i]: i
            for i in range(len(d.relations))
        }

        train_data_idxs = self.get_data_idxs(d.train_data)
        print("Number of training data points: %d" % len(train_data_idxs))

        model = GETD(d, self.ent_vec_dim, self.rel_vec_dim, self.k, self.ni,
                     self.ranks, device, **self.kwargs)
        model = model.to(device)

        opt = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
        if self.decay_rate:
            scheduler = ExponentialLR(opt, self.decay_rate)

        print("Starting training...")
        best_valid_iter = 0
        best_valid_metric = {
            'mrr': -1,
            'test_mrr': -1,
            'test_hit1': -1,
            'test_hit3': -1,
            'test_hit10': -1
        }

        er_vocab = self.get_er_vocab(train_data_idxs)
        er_vocab_pairs = list(er_vocab.keys())
        for it in range(1, self.num_iterations + 1):
            model.train()
            losses = []
            np.random.shuffle(er_vocab_pairs)

            for j in range(0, len(er_vocab_pairs), self.batch_size):
                data_batch, label = self.get_batch(er_vocab, er_vocab_pairs, j)
                opt.zero_grad()
                e1_idx = torch.tensor(data_batch[:, 0],
                                      dtype=torch.long).to(device)
                r_idx = torch.tensor(data_batch[:, 1],
                                     dtype=torch.long).to(device)

                pred, W = model.forward(e1_idx, r_idx)
                pred = pred.to(device)
                loss = model.loss(pred, label)
                loss.backward()
                opt.step()

                losses.append(loss.item())

            print('\nEpoch %d train, loss=%f' % (it, np.mean(losses, axis=0)))

            if self.decay_rate:
                scheduler.step()

            model.eval()
            with torch.no_grad():
                v_mrr, v_hit10, v_hit3, v_hit1 = self.evaluate(
                    model, d.valid_data, W)
                print(
                    'Epoch %d valid, MRR=%.8f, Hits@10=%f, Hits@3=%f, Hits@1=%f'
                    % (it, v_mrr, v_hit10, v_hit3, v_hit1))
                t_mrr, t_hit10, t_hit3, t_hit1 = self.evaluate(
                    model, d.test_data, W)

                if v_mrr > best_valid_metric['mrr']:
                    best_valid_iter = it
                    print('======== MRR on validation set increases ======== ')
                    best_valid_metric['mrr'] = v_mrr
                    best_valid_metric['test_mrr'] = t_mrr
                    best_valid_metric['test_hit1'] = t_hit1
                    best_valid_metric['test_hit3'] = t_hit3
                    best_valid_metric['test_hit10'] = t_hit10
                else:
                    print(
                        '====Current Epoch:%d, Best Epoch:%d, valid_MRR didn\'t increase for %d Epoch, best test_MRR=%f'
                        % (it, best_valid_iter, it - best_valid_iter,
                           best_valid_metric['test_mrr']))
                print(
                    'Epoch %d test, MRR=%.8f, Hits@10=%f, Hits@3=%f, Hits@1=%f'
                    % (it, t_mrr, t_hit10, t_hit3, t_hit1))

            if (it - best_valid_iter) >= 10 or it == self.num_iterations:
                print('++++++++++++ Early Stopping +++++++++++++')
                print('Best epoch %d' % best_valid_iter)
                print('Mean reciprocal rank: {0}'.format(
                    best_valid_metric['test_mrr']))
                print('Hits @10: {0}'.format(best_valid_metric['test_hit10']))
                print('Hits @3: {0}'.format(best_valid_metric['test_hit3']))
                print('Hits @1: {0}'.format(best_valid_metric['test_hit1']))
                break
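
Note how the example only creates and steps the scheduler when decay_rate is set. A compact way to express the same optional-decay pattern (function and variable names here are assumptions):

import torch
from torch.optim.lr_scheduler import ExponentialLR

def make_scheduler(optimizer, decay_rate=None):
    # return an ExponentialLR only when a decay rate is configured, otherwise None
    return ExponentialLR(optimizer, decay_rate) if decay_rate else None

opt = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-3)
sched = make_scheduler(opt, decay_rate=0.995)

for _ in range(5):
    opt.step()
    if sched is not None:
        sched.step()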
Example #30
0
 def test_exp_step_lr_state_dict(self):
     self._check_scheduler_state_dict(
         lambda: ExponentialLR(self.opt, gamma=0.1),
         lambda: ExponentialLR(self.opt, gamma=0.01))
Example #31
0
 def set_params(self, transformer, validation_datagen, *args, **kwargs):
     self.validation_datagen = validation_datagen
     self.model = transformer.model
     self.optimizer = transformer.optimizer
     self.loss_function = transformer.loss_function
     self.lr_scheduler = ExponentialLR(self.optimizer, self.gamma, last_epoch=-1)
Example #32
0
    def train(
        self, base_path: Union[Path, str],
        fix_len=20,
        min_freq=2,
        buckets=1000,
        batch_size=5000,
        lr=2e-3,
        mu=.9,
        nu=.9,
        epsilon=1e-12,
        clip=5.0,
        decay=.75,
        decay_steps=5000,
        patience=100,
        max_epochs=10,
        wandb=None
    ):
        r"""
        Train any class that implement model interface

        Args:
            base_path (object): Main path to which all output during training is logged and models are saved
            max_epochs: Maximum number of epochs to train. Terminates training if this number is surpassed.
            patience:
            decay_steps:
            decay:
            clip:
            epsilon:
            nu:
            mu:
            lr:
            proj:
            tree:
            batch_size:
            buckets:
            min_freq:
            fix_len:


        """
        ################################################################################################################
        # BUILD
        ################################################################################################################
        feat = self.parser.feat
        embed = self.parser.embed
        os.makedirs(os.path.dirname(base_path), exist_ok=True)
        logger.info("Building the fields")
        WORD = Field('words', pad=pad, unk=unk, bos=bos, lower=True)
        if feat == 'char':
            FEAT = SubwordField('chars', pad=pad, unk=unk, bos=bos, fix_len=fix_len)
        elif feat == 'bert':
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(self.parser.bert)
            FEAT = SubwordField('bert',
                                pad=tokenizer.pad_token,
                                unk=tokenizer.unk_token,
                                bos=tokenizer.bos_token or tokenizer.cls_token,
                                fix_len=fix_len,
                                tokenize=tokenizer.tokenize)
            FEAT.vocab = tokenizer.get_vocab()
        else:
            FEAT = Field('tags', bos=bos)

        ARC = Field('arcs', bos=bos, use_vocab=False, fn=CoNLL.get_arcs)
        REL = Field('rels', bos=bos)
        if feat in ('char', 'bert'):
            transform = CoNLL(FORM=(WORD, FEAT), HEAD=ARC, DEPREL=REL)
        else:
            transform = CoNLL(FORM=WORD, CPOS=FEAT, HEAD=ARC, DEPREL=REL)

        train = Dataset(transform, self.corpus.train)
        WORD.build(train, min_freq, (Embedding.load(embed, unk) if self.parser.embed else None))
        FEAT.build(train)
        REL.build(train)
        n_words = WORD.vocab.n_init
        n_feats = len(FEAT.vocab)
        n_rels = len(REL.vocab)
        pad_index = WORD.pad_index
        unk_index = WORD.unk_index
        feat_pad_index = FEAT.pad_index
        parser = DependencyParser(
            n_words=n_words,
            n_feats=n_feats,
            n_rels=n_rels,
            pad_index=pad_index,
            unk_index=unk_index,
            feat_pad_index=feat_pad_index,
            transform=transform,
            feat=self.parser.feat,
            bert=self.parser.bert
        )
        # word_field_embeddings = self.parser.embeddings[0]
        # word_field_embeddings.n_vocab = 100
        parser.embeddings = self.parser.embeddings
        # parser.embeddings[0] = word_field_embeddings
        parser.load_pretrained(WORD.embed).to(device)

        ################################################################################################################
        # TRAIN
        ################################################################################################################
        if wandb:
            wandb.watch(parser)
        parser.transform.train()
        if dist.is_initialized():
            batch_size = batch_size // dist.get_world_size()
        logger.info('Loading the data')
        train = Dataset(parser.transform, self.corpus.train)
        dev = Dataset(parser.transform, self.corpus.dev)
        test = Dataset(parser.transform, self.corpus.test)
        train.build(batch_size, buckets, True, dist.is_initialized())
        dev.build(batch_size, buckets)
        test.build(batch_size, buckets)
        logger.info(f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")
        logger.info(f'{parser}')
        if dist.is_initialized():
            parser = DDP(parser, device_ids=[dist.get_rank()], find_unused_parameters=True)

        optimizer = Adam(parser.parameters(), lr, (mu, nu), epsilon)
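        # ExponentialLR with gamma = decay ** (1 / decay_steps): stepping once per batch
        # multiplies the lr by `decay` every `decay_steps` scheduler steps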
        scheduler = ExponentialLR(optimizer, decay ** (1 / decay_steps))

        elapsed = timedelta()
        best_e, best_metric = 1, Metric()

        for epoch in range(1, max_epochs + 1):
            start = datetime.now()
            logger.info(f'Epoch {epoch} / {max_epochs}:')

            parser.train()

            bar = progress_bar(train.loader)
            metric = AttachmentMetric()
            for words, feats, arcs, rels in bar:
                optimizer.zero_grad()

                mask = words.ne(parser.WORD.pad_index)
                # ignore the first token of each sentence
                mask[:, 0] = 0
                s_arc, s_rel = parser.forward(words, feats)
                loss = parser.forward_loss(s_arc, s_rel, arcs, rels, mask)
                loss.backward()
                nn.utils.clip_grad_norm_(parser.parameters(), clip)
                optimizer.step()
                scheduler.step()

                arc_preds, rel_preds = parser.decode(s_arc, s_rel, mask)
                # ignore all punctuation if not specified
                if not self.parser.args['punct']:
                    mask &= words.unsqueeze(-1).ne(parser.puncts).all(-1)
                metric(arc_preds, rel_preds, arcs, rels, mask)
                bar.set_postfix_str(f'lr: {scheduler.get_last_lr()[0]:.4e} - loss: {loss:.4f} - {metric}')

            dev_loss, dev_metric = parser.evaluate(dev.loader)
            logger.info(f"{'dev:':6} - loss: {dev_loss:.4f} - {dev_metric}")
            test_loss, test_metric = parser.evaluate(test.loader)
            logger.info(f"{'test:':6} - loss: {test_loss:.4f} - {test_metric}")
            if wandb:
                wandb.log({"test_loss": test_loss})
                wandb.log({"test_metric_uas": test_metric.uas})
                wandb.log({"test_metric_las": test_metric.las})

            t = datetime.now() - start
            # save the model if it is the best so far
            if dev_metric > best_metric:
                best_e, best_metric = epoch, dev_metric
                if is_master():
                    parser.save(base_path)
                logger.info(f'{t}s elapsed (saved)\n')
            else:
                logger.info(f'{t}s elapsed\n')
            elapsed += t
            if epoch - best_e >= patience:
                break
        loss, metric = parser.load(base_path).evaluate(test.loader)

        logger.info(f'Epoch {best_e} saved')
        logger.info(f"{'dev:':6} - {best_metric}")
        logger.info(f"{'test:':6} - {metric}")
        logger.info(f'{elapsed}s elapsed, {elapsed / epoch}s/epoch')