class ExponentialLRScheduler(Callback):
    """Callback that decays the optimizer learning rate via ``ExponentialLR``.

    The wrapped scheduler is stepped from ``on_epoch_end`` and/or
    ``on_batch_end``, depending on which period is enabled.

    Args:
        gamma: Decay factor handed to ``ExponentialLR``.
        epoch_every: Step when ``(epoch_id + 1)`` is a multiple of this
            value; ``0`` disables epoch-based stepping.
        batch_every: Step when ``batch_id`` is a multiple of this value;
            ``0`` or ``None`` disables batch-based stepping.
    """

    def __init__(self, gamma, epoch_every=1, batch_every=None):
        super().__init__()
        self.gamma = gamma
        # A period of 0 is stored as False so the hooks can test truthiness.
        self.epoch_every = False if epoch_every == 0 else epoch_every
        self.batch_every = False if batch_every == 0 else batch_every

    def set_params(self, transformer, validation_datagen, *args, **kwargs):
        """Grab model/optimizer/loss from ``transformer`` and build the scheduler."""
        self.validation_datagen = validation_datagen
        self.model = transformer.model
        self.optimizer = transformer.optimizer
        self.loss_function = transformer.loss_function
        self.lr_scheduler = ExponentialLR(self.optimizer,
                                          self.gamma,
                                          last_epoch=-1)

    def on_train_begin(self, *args, **kwargs):
        """Reset the epoch/batch counters and log the initial learning rate."""
        self.epoch_id = 0
        self.batch_id = 0
        first_group = self.optimizer.state_dict()['param_groups'][0]
        logger.info('initial lr: {0}'.format(first_group['initial_lr']))

    def on_epoch_end(self, *args, **kwargs):
        """Step the scheduler on every ``epoch_every``-th finished epoch."""
        finished = self.epoch_id + 1
        if self.epoch_every and finished % self.epoch_every == 0:
            self.lr_scheduler.step()
            current_lr = self.optimizer.state_dict()['param_groups'][0]['lr']
            logger.info('epoch {0} current lr: {1}'.format(finished,
                                                           current_lr))
        self.epoch_id += 1

    def on_batch_end(self, *args, **kwargs):
        """Step the scheduler whenever ``batch_id`` hits the batch period.

        ``batch_id`` is only reset in ``on_train_begin``, so it counts
        batches across epochs.
        """
        # NOTE(review): unlike on_epoch_end this tests batch_id, not
        # batch_id + 1, so the very first batch already triggers a step --
        # confirm this is intended.
        if self.batch_every and self.batch_id % self.batch_every == 0:
            self.lr_scheduler.step()
            current_lr = self.optimizer.state_dict()['param_groups'][0]['lr']
            logger.info('epoch {0} batch {1} current lr: {2}'.format(
                self.epoch_id + 1, self.batch_id + 1, current_lr))
        self.batch_id += 1
Exemple #2
0
class TrackerSiamFC(Tracker):
    def __init__(self, net_path=None, **kwargs):
        """Build the network, loss, optimizer and learning-rate schedule.

        Args:
            net_path: Optional path to a saved ``state_dict``; loaded into
                the network when it is not ``None``.
            **kwargs: Overrides for the defaults defined in ``parse_args``.
        """
        super(TrackerSiamFC, self).__init__('SiamFC', True)
        self.cfg = self.parse_args(**kwargs)

        # run on the first GPU when CUDA is available, otherwise on the CPU
        self.cuda = torch.cuda.is_available()
        device_name = 'cuda:0' if self.cuda else 'cpu'
        self.device = torch.device(device_name)

        # network: capsule backbone with the SiamFC_V2 head
        # (alternative kept from the original source, disabled:
        #  Net(backbone=AlexNetV1(), head=SiamFC(self.cfg.out_scale)))
        backbone = CapsuleNetV1()
        head = SiamFC_V2(self.cfg.out_scale)
        self.net = Net(backbone=backbone, head=head)
        ops.init_weights(self.net)

        # optional checkpoint; the map_location callback returns each storage
        # unchanged instead of moving it to the device it was saved from
        if net_path is not None:
            state = torch.load(net_path,
                               map_location=lambda storage, loc: storage)
            self.net.load_state_dict(state)
        self.net = self.net.to(self.device)

        # loss
        self.criterion = BalancedLoss()

        # SGD over all network parameters
        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=self.cfg.initial_lr,
                                   weight_decay=self.cfg.weight_decay,
                                   momentum=self.cfg.momentum)

        # per-epoch decay factor chosen so that initial_lr * gamma**epoch_num
        # equals ultimate_lr
        lr_ratio = self.cfg.ultimate_lr / self.cfg.initial_lr
        gamma = np.power(lr_ratio, 1.0 / self.cfg.epoch_num)
        self.lr_scheduler = ExponentialLR(self.optimizer, gamma)

    def parse_args(self, **kwargs):
        """Merge keyword overrides into the default configuration.

        Keyword arguments whose name is not one of the defaults are ignored.

        Returns:
            A ``Config`` namedtuple with one field per configuration key.
        """
        settings = dict(
            # basic parameters
            out_scale=0.001,
            exemplar_sz=127,
            instance_sz=255,
            context=0.5,
            # inference parameters
            scale_num=3,
            scale_step=1.0375,
            scale_lr=0.59,
            scale_penalty=0.9745,
            window_influence=0.176,
            response_sz=17,
            response_up=16,
            total_stride=8,
            # train parameters
            epoch_num=50,
            batch_size=8,
            num_workers=32,
            initial_lr=1e-2,
            ultimate_lr=1e-5,
            weight_decay=5e-4,
            momentum=0.9,
            r_pos=16,
            r_neg=0,
        )

        # only known keys may be overridden
        overrides = {
            name: value
            for name, value in kwargs.items() if name in settings
        }
        settings.update(overrides)

        config_type = namedtuple('Config', settings.keys())
        return config_type(**settings)

    @torch.no_grad()
    def init(self, img, box):
        """Initialise the tracker state from the first frame.

        Stores the target centre/size, the cosine window, the scale pyramid,
        the exemplar/search crop sizes, the mean image colour and the
        exemplar features (``self.kernel``).

        Args:
            img: First frame; averaged over axes 0 and 1 for the border colour.
            box: Target box, treated as 1-indexed, left-top based
                ``[x, y, w, h]`` by the conversion below.
        """
        # inference only
        self.net.eval()

        # to 0-indexed, centre based [y, x, h, w]
        left, top, width, height = box[0], box[1], box[2], box[3]
        box = np.array([
            top - 1 + (height - 1) / 2,
            left - 1 + (width - 1) / 2,
            height,
            width,
        ], dtype=np.float32)
        self.center, self.target_sz = box[:2], box[2:]

        # normalised 2-D hanning window at the upsampled response size
        self.upscale_sz = self.cfg.response_up * self.cfg.response_sz
        hann_1d = np.hanning(self.upscale_sz)
        self.hann_window = np.outer(hann_1d, hann_1d)
        self.hann_window /= self.hann_window.sum()

        # scale_step raised to symmetric exponents around 0
        half_span = self.cfg.scale_num // 2
        exponents = np.linspace(-half_span, half_span, self.cfg.scale_num)
        self.scale_factors = self.cfg.scale_step**exponents

        # exemplar (z) and search (x) crop sizes in image pixels
        context = self.cfg.context * np.sum(self.target_sz)
        self.z_sz = np.sqrt(np.prod(self.target_sz + context))
        size_ratio_source = self.z_sz * self.cfg.instance_sz
        self.x_sz = size_ratio_source / self.cfg.exemplar_sz

        # exemplar crop, padded with the mean colour of the frame
        self.avg_color = np.mean(img, axis=(0, 1))
        exemplar = ops.crop_and_resize(img,
                                       self.center,
                                       self.z_sz,
                                       out_size=self.cfg.exemplar_sz,
                                       border_value=self.avg_color)

        # HWC array -> 1xCxHxW float tensor, then embed
        exemplar = torch.from_numpy(exemplar).to(self.device)
        exemplar = exemplar.permute(2, 0, 1).unsqueeze(0).float()
        self.kernel = self.net.backbone(exemplar)

    @torch.no_grad()
    def update(self, img):
        """Locate the target in a new frame and update the tracker state.

        Args:
            img: Current frame.

        Returns:
            ``np.ndarray`` holding the 1-indexed, left-top based box
            ``[x, y, w, h]``.
        """
        # inference only
        self.net.eval()

        # one search crop per scale factor
        crops = [
            ops.crop_and_resize(img,
                                self.center,
                                self.x_sz * factor,
                                out_size=self.cfg.instance_sz,
                                border_value=self.avg_color)
            for factor in self.scale_factors
        ]
        search = torch.from_numpy(np.stack(crops, axis=0)).to(self.device)
        search = search.permute(0, 3, 1, 2).float()

        # correlate the stored exemplar features with the search features
        features = self.net.backbone(search)
        responses = self.net.head(self.kernel, features)
        responses = responses.squeeze(1).cpu().numpy()

        # upsample each response map, then damp every non-central scale
        target_shape = (self.upscale_sz, self.upscale_sz)
        responses = np.stack([
            cv2.resize(resp, target_shape, interpolation=cv2.INTER_CUBIC)
            for resp in responses
        ])
        middle = self.cfg.scale_num // 2
        responses[:middle] *= self.cfg.scale_penalty
        responses[middle + 1:] *= self.cfg.scale_penalty

        # scale whose map holds the highest peak
        scale_id = np.argmax(np.amax(responses, axis=(1, 2)))

        # normalise the winning map (in place) and blend with the window
        response = responses[scale_id]
        response -= response.min()
        response /= response.sum() + 1e-16
        influence = self.cfg.window_influence
        response = (1 - influence) * response + influence * self.hann_window
        peak = np.unravel_index(response.argmax(), response.shape)

        # peak offset: response pixels -> search-crop pixels -> image pixels
        disp_in_response = np.array(peak) - (self.upscale_sz - 1) / 2
        disp_in_instance = (disp_in_response * self.cfg.total_stride /
                            self.cfg.response_up)
        disp_in_image = (disp_in_instance * self.x_sz *
                         self.scale_factors[scale_id] / self.cfg.instance_sz)
        self.center += disp_in_image

        # damped scale update, applied to target and crop sizes alike
        best_factor = self.scale_factors[scale_id]
        scale = (1 - self.cfg.scale_lr) * 1.0 + self.cfg.scale_lr * best_factor
        self.target_sz *= scale
        self.z_sz *= scale
        self.x_sz *= scale

        # back to 1-indexed, left-top based [x, y, w, h]
        center_y, center_x = self.center[0], self.center[1]
        box_h, box_w = self.target_sz[0], self.target_sz[1]
        return np.array([
            center_x + 1 - (box_w - 1) / 2,
            center_y + 1 - (box_h - 1) / 2,
            box_w,
            box_h,
        ])

    def track(self, img_files, box, visualize=False):
        """Run the tracker over a sequence of image files.

        Args:
            img_files: Image paths; the first one initialises the tracker.
            box: Initial target box, stored as the result for frame 0.
            visualize: When true, show every frame with its box.

        Returns:
            ``(boxes, times)``: a ``(frame_num, 4)`` array of boxes and a
            ``(frame_num,)`` array of per-frame processing times in seconds.
        """
        frame_num = len(img_files)
        times = np.zeros(frame_num)
        boxes = np.zeros((frame_num, 4))
        boxes[0] = box

        for idx, path in enumerate(img_files):
            frame = ops.read_image(path)

            # time only the tracker call, not the image read
            tic = time.time()
            if idx > 0:
                boxes[idx, :] = self.update(frame)
            else:
                self.init(frame, box)
            times[idx] = time.time() - tic

            if visualize:
                ops.show_image(frame, boxes[idx, :])

        return boxes, times

    def train_step(self, batch, backward=True):
        """Run one forward pass and, optionally, one optimisation step.

        Args:
            batch: Indexable whose items 0 and 1 are the two tensors fed to
                the network (moved to ``self.device`` here).
            backward: When true the network is put in train mode and the
                optimizer is stepped; otherwise the pass runs in eval mode
                with gradients disabled.

        Returns:
            The loss as a Python float.
        """
        # train(True) / train(False) doubles as the eval switch
        self.net.train(backward)

        exemplars = batch[0].to(self.device, non_blocking=self.cuda)
        searches = batch[1].to(self.device, non_blocking=self.cuda)

        with torch.set_grad_enabled(backward):
            scores = self.net(exemplars, searches)

            targets = self._create_labels(scores.size())
            loss = self.criterion(scores, targets)

            if backward:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        return loss.item()

    @torch.enable_grad()
    def train_over(self, seqs, val_seqs=None, save_dir='pretrained'):
        """Train the network for ``cfg.epoch_num`` epochs, saving every epoch.

        Args:
            seqs: Sequences handed to ``Pair`` to build the training set.
            val_seqs: Accepted for interface compatibility; not used here.
            save_dir: Directory receiving one checkpoint per epoch
                (``siamfc_alexnet_e<epoch>.pth``); created if missing.
        """
        # set to train mode
        self.net.train()

        # create save_dir folder (exist_ok avoids the check-then-create race)
        os.makedirs(save_dir, exist_ok=True)

        # setup dataset
        transforms = SiamFCTransforms(exemplar_sz=self.cfg.exemplar_sz,
                                      instance_sz=self.cfg.instance_sz,
                                      context=self.cfg.context)
        dataset = Pair(seqs=seqs, transforms=transforms)

        # setup dataloader
        dataloader = DataLoader(dataset,
                                batch_size=self.cfg.batch_size,
                                shuffle=True,
                                num_workers=self.cfg.num_workers,
                                pin_memory=self.cuda,
                                drop_last=True)
        num_batches = len(dataloader)

        # loop over epochs
        for epoch in range(self.cfg.epoch_num):
            # loop over dataloader
            for it, batch in enumerate(dataloader):
                loss = self.train_step(batch, backward=True)
                print('Epoch: {} [{}/{}] Loss: {:.5f}'.format(
                    epoch + 1, it + 1, num_batches, loss))
                sys.stdout.flush()

            # Decay the lr only after this epoch's optimizer steps. The first
            # epoch therefore runs at initial_lr and epoch e at
            # initial_lr * gamma**e, matching the former
            # `step(epoch=epoch)` call made at the start of each epoch,
            # without the deprecated `epoch` argument and without stepping
            # the scheduler before the optimizer.
            self.lr_scheduler.step()

            # save checkpoint (re-create the folder in case it was removed
            # while training)
            os.makedirs(save_dir, exist_ok=True)
            net_path = os.path.join(save_dir,
                                    'siamfc_alexnet_e%d.pth' % (epoch + 1))
            torch.save(self.net.state_dict(), net_path)

    def _create_labels(self, size):
        """Return the label tensor for responses of shape ``size``.

        The tensor is cached on ``self.labels`` and reused while the
        requested size stays the same.

        Args:
            size: ``(n, c, h, w)`` shape of the response tensor.

        Returns:
            Float tensor of shape ``size`` on ``self.device``: 1 where the
            block distance to the map centre is ``<= r_pos``, 0.5 where it is
            ``< r_neg``, and 0 elsewhere (radii are in response-map cells).
        """
        # reuse the cached tensor when the shape is unchanged
        if hasattr(self, 'labels') and self.labels.size() == size:
            return self.labels

        n, c, h, w = size

        # signed offsets of every cell from the map centre
        offsets_x = np.arange(w) - (w - 1) / 2
        offsets_y = np.arange(h) - (h - 1) / 2
        grid_x, grid_y = np.meshgrid(offsets_x, offsets_y)

        # radii converted from image pixels to response cells
        r_pos = self.cfg.r_pos / self.cfg.total_stride
        r_neg = self.cfg.r_neg / self.cfg.total_stride

        # block (L1) distance decides the label of each cell
        dist = np.abs(grid_x) + np.abs(grid_y)
        outer = np.where(dist < r_neg,
                         np.ones_like(grid_x) * 0.5,
                         np.zeros_like(grid_x))
        single_map = np.where(dist <= r_pos, np.ones_like(grid_x), outer)

        # replicate the single map over batch and channel dimensions
        tiled = np.tile(single_map.reshape((1, 1, h, w)), (n, c, 1, 1))

        self.labels = torch.from_numpy(tiled).to(self.device).float()
        return self.labels
Exemple #3
0
    def train_and_eval(self):
        """Train a TuckER model on the module-level dataset ``d`` and evaluate it.

        Side effects:
            * Limits torch to two threads.
            * Writes the entity and relation index maps to
              ``../data/<dataset>/entities.dict`` and ``relations.dict``,
              one ``name<TAB>index`` pair per line.
            * Every ``valid_steps`` iterations evaluates on the validation
              and test splits; when the first validation metric is at least
              the best seen so far, ``write_embedding_files`` is called.
            * Prints progress and hyper-parameters to stdout.
        """
        torch.set_num_threads(2)
        best_valid = [0, 0, 0, 0, 0]
        best_test = [0, 0, 0, 0, 0]
        self.entity_idxs = {
            entity: idx
            for idx, entity in enumerate(d.entities)
        }
        self.relation_idxs = {
            relation: idx
            for idx, relation in enumerate(d.relations)
        }

        # Persist both index maps; the context manager closes each file even
        # if a write fails part-way through.
        index_files = (('entities.dict', self.entity_idxs),
                       ('relations.dict', self.relation_idxs))
        for file_name, mapping in index_files:
            with open('../data/' + self.dataset + '/' + file_name, 'w') as f:
                for key, value in mapping.items():
                    f.write(key + '\t' + str(value) + '\n')

        train_data_idxs = self.get_data_idxs(d.train_data)
        print("Number of training data points: %d" % len(train_data_idxs))
        print('Entities: %d' % len(self.entity_idxs))
        print('Relations: %d' % len(self.relation_idxs))
        model = TuckER(d, self.ent_vec_dim, self.rel_vec_dim, **self.kwargs)
        model.init()
        # an empty load_from means "start from the fresh initialisation"
        if self.load_from != '':
            checkpoint = torch.load(self.load_from)
            model.load_state_dict(checkpoint)
        if self.cuda:
            model.cuda()
        opt = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
        # the scheduler only exists when a decay rate is configured; the
        # same condition guards its use below
        if self.decay_rate:
            scheduler = ExponentialLR(opt, self.decay_rate)

        er_vocab = self.get_er_vocab(train_data_idxs)
        er_vocab_pairs = list(er_vocab.keys())

        print("Starting training...")

        for it in range(1, self.num_iterations + 1):
            start_train = time.time()
            model.train()
            losses = []
            np.random.shuffle(er_vocab_pairs)
            for j in tqdm(range(0, len(er_vocab_pairs), self.batch_size)):
                data_batch, targets = self.get_batch(er_vocab, er_vocab_pairs,
                                                     j)
                opt.zero_grad()
                e1_idx = torch.tensor(data_batch[:, 0])
                r_idx = torch.tensor(data_batch[:, 1])
                if self.cuda:
                    e1_idx = e1_idx.cuda()
                    r_idx = r_idx.cuda()
                predictions = model.forward(e1_idx, r_idx)
                if self.label_smoothing:
                    targets = ((1.0 - self.label_smoothing) *
                               targets) + (1.0 / targets.size(1))
                loss = model.loss(predictions, targets)
                loss.backward()
                opt.step()
                losses.append(loss.item())
            # decay the learning rate once per iteration over the data
            if self.decay_rate:
                scheduler.step()
            if it % 100 == 0:
                print('Epoch', it, ' Epoch time',
                      time.time() - start_train, ' Loss:', np.mean(losses))
            model.eval()

            with torch.no_grad():
                if it % self.valid_steps == 0:
                    start_test = time.time()
                    print("Validation:")
                    valid = self.evaluate(model, d.valid_data)
                    print("Test:")
                    test = self.evaluate(model, d.test_data)
                    valid_mrr = valid[0]
                    test_mrr = test[0]
                    # ties count as an improvement (>=), so the embeddings
                    # are rewritten on equal validation scores as well
                    if valid_mrr >= best_valid[0]:
                        best_valid = valid
                        best_test = test
                        print('Validation MRR increased.')
                        print('Saving model...')
                        write_embedding_files(model)
                        print('Model saved!')

                    print('Best valid:', best_valid)
                    print('Best Test:', best_test)
                    print('Dataset:', self.dataset)
                    print('Model:', self.model)

                    print(time.time() - start_test)
                    print(
                        'Learning rate %f | Decay %f | Dim %d | Input drop %f | Hidden drop 2 %f | LS %f | Batch size %d | Loss type %s | L3 reg %f'
                        %
                        (self.learning_rate, self.decay_rate, self.ent_vec_dim,
                         self.kwargs["input_dropout"],
                         self.kwargs["hidden_dropout2"], self.label_smoothing,
                         self.batch_size, self.loss_type, self.l3_reg))
Exemple #4
0
def main(rank, args):
    """Train and validate a Wav2Letter model as configured by ``args``.

    Sets up (optionally distributed) training, builds the feature
    transforms, datasets, model, optimizer, scheduler and data loaders,
    restores or creates a checkpoint, then alternates training and
    validation for epochs ``args.start_epoch`` .. ``args.epochs - 1``,
    saving a checkpoint after every epoch.

    Args:
        rank: Index of this process; used for distributed setup, device
            selection and to tell the main rank from the others.
        args: Options namespace. ``args.start_epoch`` is overwritten when a
            checkpoint is loaded.

    Raises:
        ValueError: If ``args.type``, ``args.decoder``, ``args.optimizer``
            or ``args.scheduler`` names an unsupported option.
    """

    # Distributed setup

    if args.distributed:
        setup_distributed(rank, args.world_size)

    not_main_rank = args.distributed and rank != 0

    logging.info("Start time: %s", datetime.now())

    # Explicitly set seed to make sure models created in separate processes
    # start from same random weights and biases
    torch.manual_seed(args.seed)

    # Empty CUDA cache
    torch.cuda.empty_cache()

    # Change backend for flac files
    torchaudio.set_audio_backend("soundfile")

    # Transforms

    melkwargs = {
        "n_fft": args.win_length,
        "n_mels": args.n_bins,
        "hop_length": args.hop_length,
    }

    sample_rate_original = 16000

    if args.type == "mfcc":
        transforms = torch.nn.Sequential(
            torchaudio.transforms.MFCC(
                sample_rate=sample_rate_original,
                n_mfcc=args.n_bins,
                melkwargs=melkwargs,
            ), )
        num_features = args.n_bins
    elif args.type == "waveform":
        transforms = torch.nn.Sequential(UnsqueezeFirst())
        num_features = 1
    else:
        raise ValueError("Model type not supported")

    if args.normalize:
        transforms = torch.nn.Sequential(transforms, Normalize())

    # Augmentations are only handed to the training collate function below.
    augmentations = torch.nn.Sequential()
    if args.freq_mask:
        augmentations = torch.nn.Sequential(
            augmentations,
            torchaudio.transforms.FrequencyMasking(
                freq_mask_param=args.freq_mask),
        )
    if args.time_mask:
        augmentations = torch.nn.Sequential(
            augmentations,
            torchaudio.transforms.TimeMasking(time_mask_param=args.time_mask),
        )

    # Text preprocessing

    char_blank = "*"
    char_space = " "
    char_apostrophe = "'"
    labels = char_blank + char_space + char_apostrophe + string.ascii_lowercase
    language_model = LanguageModel(labels, char_blank, char_space)

    # Dataset

    training, validation = split_process_librispeech(
        [args.dataset_train, args.dataset_valid],
        [transforms, transforms],
        language_model,
        root=args.dataset_root,
        folder_in_archive=args.dataset_folder_in_archive,
    )

    # Decoder

    if args.decoder == "greedy":
        decoder = GreedyDecoder()
    else:
        raise ValueError("Selected decoder not supported")

    # Model

    model = Wav2Letter(
        num_classes=language_model.length,
        input_type=args.type,
        num_features=num_features,
    )

    if args.jit:
        model = torch.jit.script(model)

    if args.distributed:
        # each rank gets an equal, contiguous share of the visible GPUs
        n = torch.cuda.device_count() // args.world_size
        devices = list(range(rank * n, (rank + 1) * n))
        model = model.to(devices[0])
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=devices)
    else:
        devices = ["cuda" if torch.cuda.is_available() else "cpu"]
        model = model.to(devices[0], non_blocking=True)
        model = torch.nn.DataParallel(model)

    n = count_parameters(model)
    logging.info("Number of parameters: %s", n)

    # Optimizer

    if args.optimizer == "adadelta":
        optimizer = Adadelta(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
            eps=args.eps,
            rho=args.rho,
        )
    elif args.optimizer == "sgd":
        optimizer = SGD(
            model.parameters(),
            lr=args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "adam":
        # Adam has no ``momentum`` argument (it is configured via ``betas``);
        # forwarding args.momentum raised a TypeError, so it is not passed.
        optimizer = Adam(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
    elif args.optimizer == "adamw":
        # Same as Adam: AdamW does not accept ``momentum``.
        optimizer = AdamW(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )
    else:
        raise ValueError("Selected optimizer not supported")

    if args.scheduler == "exponential":
        scheduler = ExponentialLR(optimizer, gamma=args.gamma)
    elif args.scheduler == "reduceonplateau":
        scheduler = ReduceLROnPlateau(optimizer, patience=10, threshold=1e-3)
    else:
        raise ValueError("Selected scheduler not supported")

    criterion = torch.nn.CTCLoss(blank=language_model.mapping[char_blank],
                                 zero_infinity=False)

    # Data Loader

    collate_fn_train = collate_factory(model_length_function, augmentations)
    collate_fn_valid = collate_factory(model_length_function)

    loader_training_params = {
        "num_workers": args.workers,
        "pin_memory": True,
        "shuffle": True,
        "drop_last": True,
    }
    # validation reuses the training settings except for shuffling
    loader_validation_params = loader_training_params.copy()
    loader_validation_params["shuffle"] = False

    loader_training = DataLoader(
        training,
        batch_size=args.batch_size,
        collate_fn=collate_fn_train,
        **loader_training_params,
    )
    loader_validation = DataLoader(
        validation,
        batch_size=args.batch_size,
        collate_fn=collate_fn_valid,
        **loader_validation_params,
    )

    # Setup checkpoint

    best_loss = 1.0

    load_checkpoint = args.checkpoint and os.path.isfile(args.checkpoint)

    if args.distributed:
        torch.distributed.barrier()

    if load_checkpoint:
        logging.info("Checkpoint: loading %s", args.checkpoint)
        checkpoint = torch.load(args.checkpoint)

        # resume from the epoch stored in the checkpoint
        args.start_epoch = checkpoint["epoch"]
        best_loss = checkpoint["best_loss"]

        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

        logging.info("Checkpoint: loaded '%s' at epoch %s", args.checkpoint,
                     checkpoint["epoch"])
    else:
        logging.info("Checkpoint: not found")

        # write an initial checkpoint holding the untrained state
        save_checkpoint(
            {
                "epoch": args.start_epoch,
                "state_dict": model.state_dict(),
                "best_loss": best_loss,
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            },
            False,
            args.checkpoint,
            not_main_rank,
        )

    if args.distributed:
        torch.distributed.barrier()

    torch.autograd.set_detect_anomaly(False)

    for epoch in range(args.start_epoch, args.epochs):

        logging.info("Epoch: %s", epoch)

        train_one_epoch(
            model,
            criterion,
            optimizer,
            scheduler,
            loader_training,
            decoder,
            language_model,
            devices[0],
            epoch,
            args.clip_grad,
            not_main_rank,
            not args.reduce_lr_valid,
        )

        loss = evaluate(
            model,
            criterion,
            loader_validation,
            decoder,
            language_model,
            devices[0],
            epoch,
            not_main_rank,
        )

        # ReduceLROnPlateau is driven by the validation loss when requested
        if args.reduce_lr_valid and isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(loss)

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_loss": best_loss,
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            },
            is_best,
            args.checkpoint,
            not_main_rank,
        )

    logging.info("End time: %s", datetime.now())

    if args.distributed:
        torch.distributed.destroy_process_group()
Exemple #5
0
class TrackerSiamFC(Tracker):
    def __init__(self, net_path=None, **kargs):
        """Create the SiamFC network together with its optimizer and scheduler.

        Args:
            net_path: Optional path to a saved ``state_dict``; loaded into
                the network when it is not ``None``.
            **kargs: Overrides for the defaults defined in ``parse_args``.
        """
        super(TrackerSiamFC, self).__init__(name='SiamFC',
                                            is_deterministic=True)
        self.cfg = self.parse_args(**kargs)

        # first GPU when CUDA is available, CPU otherwise
        self.cuda = torch.cuda.is_available()
        device_name = 'cuda:0' if self.cuda else 'cpu'
        self.device = torch.device(device_name)

        # network, optionally restored from a checkpoint; the map_location
        # callback hands back each storage unchanged
        self.net = SiamFC()
        if net_path is not None:
            state = torch.load(net_path,
                               map_location=lambda storage, loc: storage)
            self.net.load_state_dict(state)
        self.net = self.net.to(self.device)

        # SGD over all network parameters
        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=self.cfg.initial_lr,
                                   weight_decay=self.cfg.weight_decay,
                                   momentum=self.cfg.momentum)

        # exponential decay with the configured factor
        decay = self.cfg.lr_decay
        self.lr_scheduler = ExponentialLR(self.optimizer, gamma=decay)

    def parse_args(self, **kargs):
        """Merge keyword overrides into the default configuration.

        Keyword arguments whose name is not one of the defaults are ignored.

        Returns:
            A ``GenericDict`` namedtuple with one field per configuration key.
        """
        settings = dict(
            # inference parameters
            exemplar_sz=127,
            instance_sz=255,
            context=0.5,
            scale_num=3,
            scale_step=1.0375,
            scale_lr=0.59,
            scale_penalty=0.9745,
            window_influence=0.176,
            response_sz=17,
            response_up=16,
            total_stride=8,
            adjust_scale=0.001,
            # train parameters
            initial_lr=0.01,
            lr_decay=0.8685113737513527,
            weight_decay=5e-4,
            momentum=0.9,
            r_pos=16,
            r_neg=0,
        )

        # only known keys may be overridden
        overrides = {
            name: value
            for name, value in kargs.items() if name in settings
        }
        settings.update(overrides)

        config_type = namedtuple('GenericDict', settings.keys())
        return config_type(**settings)

    def init(self, image, box):
        """Initialise the tracker state from the first frame.

        Stores the target centre/size, the cosine window, the scale pyramid,
        the exemplar/search crop sizes, the mean image colour, a crop of the
        initial box (``self.gt_img``), the exemplar features
        (``self.kernel``) and the converted box (``self.init_box``).

        Args:
            image: First frame; converted with ``np.asarray``.
            box: Target box, treated as 1-indexed, left-top based
                ``[x, y, w, h]`` by the conversion below.
        """
        image = np.asarray(image)

        # convert box to 0-indexed and center based [cy, cx, h, w]
        box = np.array([
            box[1] - 1 + (box[3] - 1) / 2, box[0] - 1 +
            (box[2] - 1) / 2, box[3], box[2]
        ],
                       dtype=np.float32)
        self.center, self.target_sz = box[:2], box[2:]

        # create hanning window
        self.upscale_sz = self.cfg.response_up * self.cfg.response_sz
        self.hann_window = np.outer(np.hanning(self.upscale_sz),
                                    np.hanning(self.upscale_sz))
        self.hann_window /= self.hann_window.sum()

        # search scale factors
        self.scale_factors = self.cfg.scale_step**np.linspace(
            -(self.cfg.scale_num // 2), self.cfg.scale_num // 2,
            self.cfg.scale_num)

        # exemplar and search sizes
        context = self.cfg.context * np.sum(self.target_sz)
        self.z_sz = np.sqrt(np.prod(self.target_sz + context))
        self.x_sz = self.z_sz * \
            self.cfg.instance_sz / self.cfg.exemplar_sz

        # exemplar image
        self.avg_color = np.mean(image, axis=(0, 1))
        exemplar_image = self._crop_and_resize(image,
                                               self.center,
                                               self.z_sz,
                                               out_size=self.cfg.exemplar_sz,
                                               pad_color=self.avg_color)

        # Crop of the initial box. Bounds are clamped at 0: a negative slice
        # bound would count from the far edge of the image and produce an
        # empty or wrong crop when the box overhangs the top/left border.
        # Bounds past the bottom/right edge are clipped by slicing itself.
        cy, cx, h, w = box
        top = max(0, int(cy - h / 2))
        bottom = max(0, int(cy + h / 2))
        left = max(0, int(cx - w / 2))
        right = max(0, int(cx + w / 2))
        self.gt_img = image[top:bottom, left:right]  # added by gtz

        # exemplar features: HWC array -> 1xCxHxW float tensor
        exemplar_image = torch.from_numpy(exemplar_image).to(
            self.device).permute([2, 0, 1]).unsqueeze(0).float()
        with torch.set_grad_enabled(False):
            self.net.eval()
            self.kernel = self.net.feature(exemplar_image)
        self.init_box = box  # added by gtz

    def update(self, image):
        """Locate the target in a new frame and update the tracker state.

        Args:
            image: Current frame; converted with ``np.asarray``.

        Returns:
            ``np.ndarray`` holding the 1-indexed, left-top based box
            ``[x, y, w, h]``.
        """
        image = np.asarray(image)

        # search images, one crop per scale factor
        instance_images = [
            self._crop_and_resize(image,
                                  self.center,
                                  self.x_sz * f,
                                  out_size=self.cfg.instance_sz,
                                  pad_color=self.avg_color)
            for f in self.scale_factors
        ]
        instance_images = np.stack(instance_images, axis=0)
        instance_images = torch.from_numpy(instance_images).to(
            self.device).permute([0, 3, 1, 2]).float()

        # responses; the scale comes from cfg.adjust_scale (default 0.001)
        # instead of a duplicated literal
        with torch.set_grad_enabled(False):
            self.net.eval()
            instances = self.net.feature(instance_images)
            responses = F.conv2d(instances,
                                 self.kernel) * self.cfg.adjust_scale
        responses = responses.squeeze(1).cpu().numpy()

        # upsample responses and penalize scale changes
        responses = np.stack([
            cv2.resize(t, (self.upscale_sz, self.upscale_sz),
                       interpolation=cv2.INTER_CUBIC) for t in responses
        ],
                             axis=0)
        responses[:self.cfg.scale_num // 2] *= self.cfg.scale_penalty
        responses[self.cfg.scale_num // 2 + 1:] *= self.cfg.scale_penalty

        # peak scale
        scale_id = np.argmax(np.amax(responses, axis=(1, 2)))

        # peak location: normalise the winning map (in place) and blend it
        # with the cosine window
        response = responses[scale_id]
        response -= response.min()
        response /= response.sum() + 1e-16
        response = (1 - self.cfg.window_influence) * response + \
            self.cfg.window_influence * self.hann_window
        loc = np.unravel_index(response.argmax(), response.shape)

        # locate target center:
        # response pixels -> search-crop pixels -> image pixels
        disp_in_response = np.array(loc) - self.upscale_sz // 2
        disp_in_instance = disp_in_response * \
            self.cfg.total_stride / self.cfg.response_up
        disp_in_image = disp_in_instance * self.x_sz * \
            self.scale_factors[scale_id] / self.cfg.instance_sz
        self.center += disp_in_image

        # update target size with a damped version of the winning scale
        scale = (1 - self.cfg.scale_lr) * 1.0 + \
            self.cfg.scale_lr * self.scale_factors[scale_id]
        self.target_sz *= scale
        self.z_sz *= scale
        self.x_sz *= scale

        # return 1-indexed and left-top based bounding box
        box = np.array([
            self.center[1] + 1 - (self.target_sz[1] - 1) / 2,
            self.center[0] + 1 - (self.target_sz[0] - 1) / 2,
            self.target_sz[1], self.target_sz[0]
        ])

        return box

    def step(self, batch, backward=True, update_lr=False):
        """Run one forward pass and, optionally, one optimisation step.

        Args:
            batch: Indexable whose items 0 and 1 are the two tensors fed to
                the network (moved to ``self.device`` here).
            backward: When true the network is put in train mode and the
                optimizer is stepped; otherwise the pass runs in eval mode
                with gradients disabled.
            update_lr: When true (and ``backward`` is true) the learning-rate
                scheduler is stepped before the forward pass.

        Returns:
            The weighted binary cross-entropy loss as a Python float.
        """
        if backward:
            self.net.train()
            if update_lr:
                # NOTE(review): this steps the scheduler before
                # optimizer.step(); PyTorch >= 1.1 expects the opposite
                # order. Left unchanged because moving it alters which
                # batches see the decayed rate -- confirm with callers.
                self.lr_scheduler.step()
        else:
            self.net.eval()

        z = batch[0].to(self.device)
        x = batch[1].to(self.device)

        with torch.set_grad_enabled(backward):
            responses = self.net(z, x)
            labels, weights = self._create_labels(responses.size())
            # reduction='mean' replaces the deprecated size_average=True;
            # both average the weighted loss over all elements
            loss = F.binary_cross_entropy_with_logits(responses,
                                                      labels,
                                                      weight=weights,
                                                      reduction='mean')

            if backward:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        return loss.item()

    def _crop_and_resize(self, image, center, size, out_size, pad_color):
        """Cut a square patch around ``center`` and resize it to ``out_size``.

        Args:
            image: array indexed as ``image[row, col, ...]``.
            center: length-2 array, (row, col) of the patch centre (0-indexed).
            size: side length of the square patch; rounded to an integer.
            out_size: side length of the returned patch.
            pad_color: constant border value used when the patch extends
                beyond the image.

        Returns:
            The cropped patch resized to ``(out_size, out_size)``.
        """
        side = round(size)

        # corners as [row0, col0, row1, col1]; the far corner is exclusive
        top_left = np.round(center - (side - 1) / 2)
        box = np.round(np.concatenate((top_left, top_left + side))).astype(int)

        # largest overshoot past any image edge decides the uniform border width
        overflow = np.concatenate((-box[:2], box[2:] - image.shape[:2]))
        margin = max(0, int(overflow.max()))
        if margin > 0:
            image = cv2.copyMakeBorder(image,
                                       margin,
                                       margin,
                                       margin,
                                       margin,
                                       cv2.BORDER_CONSTANT,
                                       value=pad_color)

        # shift the corners into the padded image's coordinates and crop
        row0, col0, row1, col1 = (box + margin).astype(int)
        cropped = image[row0:row1, col0:col1]

        return cv2.resize(cropped, (out_size, out_size))

    def _create_labels(self, size):
        # skip if same sized labels already created
        if hasattr(self, 'labels') and self.labels.size() == size:
            return self.labels, self.weights

        def logistic_labels(x, y, r_pos, r_neg):
            dist = np.abs(x) + np.abs(y)  # block distance
            labels = np.where(
                dist <= r_pos, np.ones_like(x),
                np.where(dist < r_neg,
                         np.ones_like(x) * 0.5, np.zeros_like(x)))
            return labels

        # distances along x- and y-axis
        n, c, h, w = size
        x = np.arange(w) - w // 2
        y = np.arange(h) - h // 2
        x, y = np.meshgrid(x, y)

        # create logistic labels
        r_pos = self.cfg.r_pos / self.cfg.total_stride
        r_neg = self.cfg.r_neg / self.cfg.total_stride
        labels = logistic_labels(x, y, r_pos, r_neg)

        # pos/neg weights
        pos_num = np.sum(labels == 1)
        neg_num = np.sum(labels == 0)
        weights = np.zeros_like(labels)
        weights[labels == 1] = 0.5 / pos_num
        weights[labels == 0] = 0.5 / neg_num
        weights *= pos_num + neg_num

        # repeat to size
        labels = labels.reshape((1, 1, h, w))
        weights = weights.reshape((1, 1, h, w))
        labels = np.tile(labels, (n, c, 1, 1))
        weights = np.tile(weights, [n, c, 1, 1])

        # convert to tensors
        self.labels = torch.from_numpy(labels).to(self.device).float()
        self.weights = torch.from_numpy(weights).to(self.device).float()

        return self.labels, self.weights

    def visualize(self, img_files, box):
        """Track through ``img_files`` and time every frame.

        The first frame initialises the tracker with ``box``; every later
        frame is passed to ``update`` and to the heatmap / colormap /
        mixed-map helpers, and the predicted box is drawn on a BGR copy of
        the frame. The drawn frame and the maps are not returned.

        Args:
            img_files: sequence of image paths.
            box: initial box, stored as row 0 of the result.

        Returns:
            ``(boxes, times)``: an ``(N, 4)`` array of boxes and a length-N
            array of per-frame processing times in seconds.
        """
        n_frames = len(img_files)
        boxes = np.zeros((n_frames, 4))
        boxes[0] = box
        times = np.zeros(n_frames)

        for idx, path in enumerate(img_files):
            frame = Image.open(path)
            if frame.mode != 'RGB':
                frame = frame.convert('RGB')

            tic = time.time()
            if idx == 0:
                self.init(frame, box)
            else:
                boxes[idx, :] = self.update(frame)  # x,y,w,h
                _, response = self.update_heatmap(frame)
                _, prob = self.update_colormap(frame)
                self.update_mixedmap(frame.size, response, prob)

                canvas = cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR)
                left, top, width, height = boxes[idx, :]
                left = int(left)
                top = int(top)
                # the far corner is derived from the already truncated origin
                right = int(left + width)
                bottom = int(top + height)
                cv2.rectangle(canvas, (left, top), (right, bottom),
                              (0, 0, 255), 2)
            times[idx] = time.time() - tic
        return boxes, times

    def visualize1(self, img_files, gt_boxes, opt, model):
        """Run the detector on a crop around each ground-truth box and save the plots.

        For every frame the ground-truth box is converted to a centre-based
        ``[y, x, h, w]`` layout, a square region twice the larger box side is
        cropped with ``my_crop``, ``detect`` is run on the crop and the first
        element of its result is handed to ``save_imgs``.

        Args:
            img_files: sequence of image paths.
            gt_boxes: per-frame boxes, indexed as ``[x, y, w, h]`` (1-indexed,
                left-top based, judging by the conversion below).
            opt: options object forwarded to ``detect`` and ``save_imgs``.
            model: detector forwarded to ``detect``.
        """
        for f, img_file in enumerate(img_files):
            image = Image.open(img_file)
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # convert box to 0-indexed and center based [y, x, h, w]
            gt = gt_boxes[f]
            box = np.array([
                gt[1] - 1 + (gt[3] - 1) / 2,
                gt[0] - 1 + (gt[2] - 1) / 2,
                gt[3],
                gt[2],
            ],
                           dtype=np.float32)
            center, target_sz = box[:2], box[2:]

            # search window: twice the larger side of the target
            search_size = 2 * np.max(target_sz)
            patch = np.asarray(self.my_crop(image, center, search_size))

            candidates = detect(opt, model, patch)[0]

            self.save_imgs(opt, candidates, patch, target_sz, f)

    def save_imgs(self, opt, candidates, image, target_sz, f):
        """Plot ``image`` with the size-filtered candidate boxes and save it.

        The figure is written to ``my_output1/{f}.png`` and closed afterwards,
        also when plotting or saving raises.

        Args:
            opt: options object; ``img_size``, ``one_scale_thres`` and
                ``two_scale_thres`` are read from it.
            candidates: ``None`` or a 2-D tensor whose rows unpack as
                ``(x1, y1, x2, y2, conf)``.
            image: array shown as the plot background.
            target_sz: target size indexed as ``[h, w]`` (it is swapped to
                ``[w, h]`` to line up with the candidate sizes).
            f: frame identifier used in the output file name.
        """
        # Bounding-box colors
        color = (0.6, 0.3, 0.3, 1)

        print("\nSaving images:")

        # One figure only: calling plt.figure() before plt.subplots() opened a
        # second, empty figure that plt.close() never released.
        fig, ax = plt.subplots(1)
        try:
            ax.imshow(image)

            # Draw bounding boxes of the detections
            if candidates is not None:
                # Rescale boxes to original image
                candidates = rescale_boxes(candidates, opt.img_size,
                                           image.shape[:2])
                candidate_sz = candidates[:, 2:4] - candidates[:, :2]

                target_wh = torch.Tensor([target_sz[1], target_sz[0]
                                          ]).expand_as(candidate_sz)

                # filter 1: drop a candidate when ANY side is out of the
                # one_scale_thres band around the target size
                too_big_any = torch.any(
                    candidate_sz > opt.one_scale_thres * target_wh, dim=1)
                too_small_any = torch.any(
                    candidate_sz < 1 / opt.one_scale_thres * target_wh, dim=1)
                valid_mask_1 = ~(too_big_any | too_small_any)

                # filter 2: drop a candidate when ALL sides are out of the
                # two_scale_thres band in the same direction
                too_big_all = torch.all(
                    candidate_sz > opt.two_scale_thres * target_wh, dim=1)
                too_small_all = torch.all(
                    candidate_sz < 1 / opt.two_scale_thres * target_wh, dim=1)
                valid_mask_2 = ~(too_big_all | too_small_all)

                candidates = candidates[valid_mask_1 & valid_mask_2]

                for x1, y1, x2, y2, conf in candidates:
                    # Create a Rectangle patch and add it to the plot
                    bbox = patches.Rectangle((x1, y1),
                                             x2 - x1,
                                             y2 - y1,
                                             linewidth=2,
                                             edgecolor=color,
                                             facecolor="none")
                    ax.add_patch(bbox)

            # Save generated image with detections
            ax.axis("off")
            ax.xaxis.set_major_locator(NullLocator())
            ax.yaxis.set_major_locator(NullLocator())

            fig.savefig(f"my_output1/{f}.png",
                        bbox_inches="tight",
                        pad_inches=0.0)
        finally:
            plt.close(fig)

    def my_crop(self, image, center, size):
        image = np.array(image)
        # convert box to corners (0-indexed)
        size = round(size)
        corners = np.concatenate((np.round(center - (size - 1) / 2),
                                  np.round(center - (size - 1) / 2) + size))
        corners = np.round(corners).astype(int)

        # pad image if necessary
        pads = np.concatenate((-corners[:2], corners[2:] - image.shape[:2]))
        npad = max(0, int(pads.max()))
        if npad > 0:
            image = cv2.copyMakeBorder(image,
                                       npad,
                                       npad,
                                       npad,
                                       npad,
                                       cv2.BORDER_CONSTANT,
                                       value=(255, 255, 255))

        # crop image patch
        corners = (corners + npad).astype(int)
        patch = image[corners[0]:corners[2], corners[1]:corners[3]]

        return patch

    def update_colormap(self, image):
        """Build a colour-likelihood map of ``image`` against ``self.gt_img``.

        Both images are summarised by a 10x10x10 colour histogram. Every pixel
        of ``image`` gets the ratio between the object-histogram count and the
        image-histogram count of its colour bin.

        Args:
            image: 3-channel image; converted with ``np.asarray``.

        Returns:
            ``(likemap, Prob1)``: the JET colour-mapped visualisation and the
            raw per-pixel ratio of shape ``image.shape[:2]``. ``Prob1`` is
            returned unclipped and may exceed 1.
        """
        image = np.asarray(image)
        hist_obj = cv2.calcHist([self.gt_img], [0, 1, 2], None, [10, 10, 10],
                                [0, 256, 0, 256, 0, 256])
        hist_img = cv2.calcHist([image], [0, 1, 2], None, [10, 10, 10],
                                [0, 256, 0, 256, 0, 256])

        # map every channel value to its bin index (10 bins of width 25.6)
        image_10 = (image / 25.6).astype('uint8')
        a = image_10[:, :, 0].ravel()
        b = image_10[:, :, 1].ravel()
        c_ = image_10[:, :, 2].ravel()

        # per pixel: bin count of that pixel's colour in each histogram
        H_obj = hist_obj[a, b, c_].astype('float')
        H_img = hist_img[a, b, c_].astype('float')

        Prob1 = np.zeros((image.shape[0] * image.shape[1], ), dtype='float')
        mask = H_img == 0
        Prob1[~mask] = np.divide(H_obj[~mask], H_img[~mask])
        Prob1[mask] = 0.1  # fallback for empty image bins
        Prob1 = Prob1.reshape((image.shape[0], image.shape[1]))

        # The ratio can be larger than 1 when the object histogram has more
        # mass in a bin than the frame; clip before scaling so the uint8 cast
        # cannot overflow. Only the visualisation is clipped.
        Prob2 = (np.clip(Prob1, 0, 1) * 255).astype('uint8')
        likemap = cv2.applyColorMap(Prob2, cv2.COLORMAP_JET)
        return likemap, Prob1

    def update_mixedmap(self, shape, response, Prob1):
        """Combine the tracker response with the colour likelihood into one map.

        The response is resized to ``(shape[0], shape[1])`` (passed to
        ``cv2.resize`` as its ``dsize``), pixels whose colour likelihood is
        very low are penalised, and the result is min-max normalised to
        0..255 and colour-mapped with JET.

        Args:
            shape: two-element size used as the ``cv2.resize`` target.
            response: 2-D response map (already normalised by the caller).
            Prob1: per-pixel colour likelihood; must broadcast against the
                resized response.

        Returns:
            The colour-mapped mixed map.
        """
        response = cv2.resize(response, (shape[0], shape[1]),
                              interpolation=cv2.INTER_CUBIC)

        # penalties grow as the colour likelihood drops
        slight_reduce_mask = Prob1 < 1e-2
        remove_mask = Prob1 < 1e-4

        mixedmap = response - slight_reduce_mask * 1e-5 - remove_mask * 5e-5
        mixedmap = np.clip(mixedmap, 0, 1)

        # min-max normalise; a constant map would otherwise divide 0 by 0
        m_min = np.min(mixedmap)
        span = np.max(mixedmap) - m_min
        if span > 0:
            mixedmap = (mixedmap - m_min) / span * 255.
        else:
            mixedmap = np.zeros_like(mixedmap)

        mixedmap = mixedmap.astype('uint8')
        mixedmap = cv2.applyColorMap(mixedmap, cv2.COLORMAP_JET)
        return mixedmap
Exemple #6
0
class TrackerSiamFC(Tracker):
    """SiamFC-style tracker whose network exposes three feature stages.

    The exemplar and every search patch go through ``net.feature1`` and
    ``net.feature2``; the two outputs are summed and passed to
    ``net.feature3``. Tracking cross-correlates the search features with the
    stored exemplar kernel; ``step`` trains the network on batches.
    """

    def __init__(self, net_path=None, **kargs):
        """Build the network, the SGD optimizer and the exponential LR scheduler.

        Args:
            net_path: optional path to a saved ``state_dict`` that is loaded
                into the network (always mapped to CPU storage first).
            **kargs: overrides for the defaults defined in ``parse_args``.
        """
        super(TrackerSiamFC, self).__init__(name='SiamFC',
                                            is_deterministic=True)
        self.cfg = self.parse_args(**kargs)

        # setup GPU device if available
        self.cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.cuda else 'cpu')

        # setup model
        self.net = SiamFC()
        if net_path is not None:
            self.net.load_state_dict(
                torch.load(net_path,
                           map_location=lambda storage, loc: storage))
        self.net = self.net.to(self.device)

        # setup optimizer
        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=self.cfg.initial_lr,
                                   weight_decay=self.cfg.weight_decay,
                                   momentum=self.cfg.momentum)

        # setup lr scheduler
        self.lr_scheduler = ExponentialLR(self.optimizer,
                                          gamma=self.cfg.lr_decay)

    def parse_args(self, **kargs):
        """Merge keyword overrides into the default configuration.

        Keys that are not part of the defaults are silently ignored.

        Returns:
            A namedtuple (``GenericDict``) with one field per setting.
        """
        # default parameters
        cfg = {
            # inference parameters
            'exemplar_sz': 135,
            'instance_sz': 263,
            'context': 0.5,
            'scale_num': 3,
            'scale_step': 1.0375,
            'scale_lr': 0.59,
            'scale_penalty': 0.9745,
            'window_influence': 0.27,  #0.176,
            'response_sz': 17,
            'response_up': 16,
            'total_stride': 8,
            'adjust_scale': 0.001,
            # train parameters
            'initial_lr': 0.01,
            'lr_decay': 0.8685113737513527,
            'weight_decay': 5e-4,
            'momentum': 0.9,
            'r_pos': 16,
            'r_neg': 0
        }

        for key, val in kargs.items():
            if key in cfg:
                cfg[key] = val
        return namedtuple('GenericDict', cfg.keys())(**cfg)

    def init(self, image, box):
        """Initialise the tracker state from the first frame.

        Args:
            image: first frame; converted with ``np.asarray``.
            box: 1-indexed, left-top based ``[x, y, w, h]`` target box.
        """
        image = np.asarray(image)

        # convert box to 0-indexed and center based [y, x, h, w]
        box = np.array([
            box[1] - 1 + (box[3] - 1) / 2, box[0] - 1 +
            (box[2] - 1) / 2, box[3], box[2]
        ],
                       dtype=np.float32)
        self.center, self.target_sz = box[:2], box[2:]

        # create hanning window (normalised to sum to 1)
        self.upscale_sz = self.cfg.response_up * self.cfg.response_sz
        self.hann_window = np.outer(np.hanning(self.upscale_sz),
                                    np.hanning(self.upscale_sz))
        self.hann_window /= self.hann_window.sum()

        # search scale factors, symmetric around 1
        self.scale_factors = self.cfg.scale_step**np.linspace(
            -(self.cfg.scale_num // 2), self.cfg.scale_num // 2,
            self.cfg.scale_num)

        # exemplar and search sizes
        context = self.cfg.context * np.sum(self.target_sz)
        self.z_sz = np.sqrt(np.prod(self.target_sz + context))
        self.x_sz = self.z_sz * \
            self.cfg.instance_sz / self.cfg.exemplar_sz

        # exemplar image, padded with the frame's mean colour when needed
        self.avg_color = np.mean(image, axis=(0, 1))
        exemplar_image = self._crop_and_resize(image,
                                               self.center,
                                               self.z_sz,
                                               out_size=self.cfg.exemplar_sz,
                                               pad_color=self.avg_color)

        # exemplar features: HWC image -> 1xCxHxW float tensor
        exemplar_image = torch.from_numpy(exemplar_image).to(
            self.device).permute([2, 0, 1]).unsqueeze(0).float()
        with torch.set_grad_enabled(False):
            self.net.eval()

            # the outputs of the first two stages are summed before stage 3
            z = self.net.feature1(exemplar_image)
            z_noise = self.net.feature2(exemplar_image)
            z = torch.add(z, z_noise)
            self.kernel = self.net.feature3(z)

    def update(self, image):
        """Locate the target in a new frame and update the tracker state.

        Args:
            image: current frame; converted with ``np.asarray``.

        Returns:
            1-indexed, left-top based ``[x, y, w, h]`` box as a numpy array.
        """
        image = np.asarray(image)

        # search images, one per scale factor
        instance_images = [
            self._crop_and_resize(image,
                                  self.center,
                                  self.x_sz * f,
                                  out_size=self.cfg.instance_sz,
                                  pad_color=self.avg_color)
            for f in self.scale_factors
        ]
        instance_images = np.stack(instance_images, axis=0)
        instance_images = torch.from_numpy(instance_images).to(
            self.device).permute([0, 3, 1, 2]).float()

        # responses
        with torch.set_grad_enabled(False):
            self.net.eval()

            x = self.net.feature1(instance_images)
            x_noise = self.net.feature2(instance_images)
            x = torch.add(x, x_noise)
            instances = self.net.feature3(x)

            # cross-correlate with the exemplar kernel and scale the result
            responses = F.conv2d(instances, self.kernel) * 0.001
        responses = responses.squeeze(1).cpu().numpy()

        # upsample responses and penalize scale changes
        responses = np.stack([
            cv2.resize(t, (self.upscale_sz, self.upscale_sz),
                       interpolation=cv2.INTER_CUBIC) for t in responses
        ],
                             axis=0)
        responses[:self.cfg.scale_num // 2] *= self.cfg.scale_penalty
        responses[self.cfg.scale_num // 2 + 1:] *= self.cfg.scale_penalty

        # peak scale
        scale_id = np.argmax(np.amax(responses, axis=(1, 2)))

        # peak location: normalise the response and blend in the hann window
        response = responses[scale_id]
        response -= response.min()
        response /= response.sum() + 1e-16
        response = (1 - self.cfg.window_influence) * response + \
            self.cfg.window_influence * self.hann_window
        loc = np.unravel_index(response.argmax(), response.shape)

        # locate target center: response -> instance -> image coordinates
        disp_in_response = np.array(loc) - self.upscale_sz // 2
        disp_in_instance = disp_in_response * \
            self.cfg.total_stride / self.cfg.response_up
        disp_in_image = disp_in_instance * self.x_sz * \
            self.scale_factors[scale_id] / self.cfg.instance_sz
        self.center += disp_in_image

        # update target size with a damped version of the winning scale
        scale =  (1 - self.cfg.scale_lr) * 1.0 + \
            self.cfg.scale_lr * self.scale_factors[scale_id]
        self.target_sz *= scale
        self.z_sz *= scale
        self.x_sz *= scale

        # return 1-indexed and left-top based bounding box
        box = np.array([
            self.center[1] + 1 - (self.target_sz[1] - 1) / 2,
            self.center[0] + 1 - (self.target_sz[0] - 1) / 2,
            self.target_sz[1], self.target_sz[0]
        ])

        return box

    def step(self, batch, backward=True, update_lr=False):
        """Run one forward pass on a batch and, optionally, one optimisation step.

        Args:
            batch: indexable whose items 0..3 are passed to the network as
                ``(z, z_noise, x, x_noise)`` after being moved to the device.
            backward: when True the network is put in train mode and the
                optimizer is stepped; otherwise eval mode, gradients disabled.
            update_lr: when True (and ``backward`` is True) the learning-rate
                scheduler is stepped before the forward pass.

        Returns:
            The loss of this batch as a Python float.
        """
        if backward:
            self.net.train()
            if update_lr:
                self.lr_scheduler.step()
        else:
            self.net.eval()

        z = batch[0].to(self.device)
        z_noise = batch[1].to(self.device)

        x = batch[2].to(self.device)
        x_noise = batch[3].to(self.device)

        with torch.set_grad_enabled(backward):
            responses = self.net(z, z_noise, x, x_noise)
            labels, weights = self._create_labels(responses.size())
            # `reduction='mean'` is the supported spelling of the deprecated
            # `size_average=True`; the computed loss is the same.
            loss = F.binary_cross_entropy_with_logits(responses,
                                                      labels,
                                                      weight=weights,
                                                      reduction='mean')

            if backward:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
        return loss.item()

    def _crop_and_resize(self, image, center, size, out_size, pad_color):
        """Cut a square patch around ``center`` and resize it to ``out_size``.

        Args:
            image: array indexed as ``image[row, col, ...]``.
            center: length-2 array, (row, col) of the patch centre (0-indexed).
            size: side length of the square patch; rounded to an integer.
            out_size: side length of the returned patch.
            pad_color: constant border value used when the patch extends
                beyond the image.
        """
        # convert box to corners (0-indexed)
        size = round(size)
        corners = np.concatenate((np.round(center - (size - 1) / 2),
                                  np.round(center - (size - 1) / 2) + size))
        corners = np.round(corners).astype(int)

        # pad image if necessary: the largest overshoot sets the border width
        pads = np.concatenate((-corners[:2], corners[2:] - image.shape[:2]))
        npad = max(0, int(pads.max()))
        if npad > 0:
            image = cv2.copyMakeBorder(image,
                                       npad,
                                       npad,
                                       npad,
                                       npad,
                                       cv2.BORDER_CONSTANT,
                                       value=pad_color)

        # crop image patch (corners shifted into the padded image)
        corners = (corners + npad).astype(int)
        patch = image[corners[0]:corners[2], corners[1]:corners[3]]

        # resize to out_size
        patch = cv2.resize(patch, (out_size, out_size))

        return patch

    def _create_labels(self, size):
        """Return cached label and weight tensors for a response of ``size``.

        Args:
            size: 4-element size ``(n, c, h, w)`` of the response tensor.

        Returns:
            ``(labels, weights)``, float tensors of shape ``size`` on
            ``self.device``.
        """
        # skip if same sized labels already created
        if hasattr(self, 'labels') and self.labels.size() == size:
            return self.labels, self.weights

        def logistic_labels(x, y, r_pos, r_neg):
            # <= r_pos -> 1, < r_neg -> 0.5, everything else -> 0
            dist = np.abs(x) + np.abs(y)  # block distance
            labels = np.where(
                dist <= r_pos, np.ones_like(x),
                np.where(dist < r_neg,
                         np.ones_like(x) * 0.5, np.zeros_like(x)))
            return labels

        # distances along x- and y-axis
        n, c, h, w = size
        x = np.arange(w) - w // 2
        y = np.arange(h) - h // 2
        x, y = np.meshgrid(x, y)

        # create logistic labels (radii expressed in response cells)
        r_pos = self.cfg.r_pos / self.cfg.total_stride
        r_neg = self.cfg.r_neg / self.cfg.total_stride
        labels = logistic_labels(x, y, r_pos, r_neg)

        # pos/neg weights: each class gets half of the total weight,
        # cells labelled 0.5 keep a weight of zero
        pos_num = np.sum(labels == 1)
        neg_num = np.sum(labels == 0)
        weights = np.zeros_like(labels)
        weights[labels == 1] = 0.5 / pos_num
        weights[labels == 0] = 0.5 / neg_num
        weights *= pos_num + neg_num

        # repeat to size
        labels = labels.reshape((1, 1, h, w))
        weights = weights.reshape((1, 1, h, w))
        labels = np.tile(labels, (n, c, 1, 1))
        weights = np.tile(weights, [n, c, 1, 1])

        # convert to tensors
        self.labels = torch.from_numpy(labels).to(self.device).float()
        self.weights = torch.from_numpy(weights).to(self.device).float()

        return self.labels, self.weights
Exemple #7
0
def run(train_loader, val_loader, epochs, lr, momentum, weight_decay, lr_step,
        k1, k2, es_patience, log_dir):
    """Train a ``Vgg16`` classifier with Ignite, logging to TensorBoard.

    Training uses SGD with an exponential LR decay stepped every ``lr_step``
    iterations. After every epoch (and once more on completion) the model is
    evaluated on ``val_loader``; the three best checkpoints by accuracy are
    kept and training stops early when accuracy stalls.

    Args:
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches.
        epochs: maximum number of training epochs.
        lr: initial SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        lr_step: the LR scheduler is stepped every ``lr_step`` iterations.
        k1: unused here (only referenced by the commented-out criterion).
        k2: unused here (only referenced by the commented-out criterion).
        es_patience: early-stopping patience, in evaluations.
        log_dir: root directory; logs and checkpoints go to
            ``<log_dir>/vgg_vae/<timestamp>``.
    """
    import gc

    model = Vgg16()

    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    model.to(device)

    optimizer = optim.SGD(model.parameters(),
                          lr=lr,
                          momentum=momentum,
                          weight_decay=weight_decay)

    lr_scheduler = ExponentialLR(optimizer, gamma=0.975)

    # criterion = VAELoss(k1=k1, k2=k2).to(device)

    def update_fn(engine, batch):
        # one optimisation step; the returned dict feeds the running averages
        x, y = _prepare_batch(batch, device=device, non_blocking=True)

        model.train()

        optimizer.zero_grad()

        output = model(x)

        # Compute loss
        loss = F.nll_loss(output, y)

        loss.backward()

        optimizer.step()

        return {
            "batchloss": loss.item(),
        }

    trainer = Engine(update_fn)

    try:
        GpuInfo().attach(trainer)
    except RuntimeError:
        print(
            "INFO: By default, in this example it is possible to log GPU information (used memory, utilization). "
            "As there is no pynvml python package installed, GPU information won't be logged. Otherwise, please "
            "install it : `pip install pynvml`")

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=lr_step),
                              lambda engine: lr_scheduler.step())

    metric_names = [
        'batchloss',
    ]

    def output_transform(x, name):
        return x[name]

    for n in metric_names:
        # We compute running average values on the output (batch loss) across all devices
        RunningAverage(output_transform=partial(output_transform, name=n),
                       epoch_bound=False,
                       device=device).attach(trainer, n)

    exp_name = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_path = log_dir + "/vgg_vae/{}".format(exp_name)

    tb_logger = TensorboardLogger(log_dir=log_path)

    # Everything below runs under try/finally so the TensorBoard writer is
    # always closed, including when setup or training raises.
    try:
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=metric_names),
                         event_name=Events.ITERATION_COMPLETED)

        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer, "lr"),
                         event_name=Events.ITERATION_STARTED)

        ProgressBar(persist=True,
                    bar_format="").attach(trainer,
                                          event_name=Events.EPOCH_STARTED,
                                          closing_event_name=Events.COMPLETED)
        ProgressBar(persist=False,
                    bar_format="").attach(trainer, metric_names=metric_names)

        # val process definition
        def loss_output_transform(output):
            return output

        def acc_output_transform(output):
            return output

        customed_loss = Loss(loss_fn=F.nll_loss,
                             output_transform=loss_output_transform,
                             device=device)
        customed_accuracy = Accuracy(output_transform=acc_output_transform,
                                     device=device)

        metrics = {'Loss': customed_loss, 'Accuracy': customed_accuracy}

        def val_update_fn(engine, batch):
            # forward pass only; metrics consume the (output, target) pair
            model.eval()
            with torch.no_grad():
                x, y = _prepare_batch(batch, device=device, non_blocking=True)
                output = model(x)
                return output, y

        val_evaluator = Engine(val_update_fn)

        for name, metric in metrics.items():
            metric.attach(val_evaluator, name)

        def run_evaluation(engine):
            val_evaluator.run(val_loader)

        trainer.add_event_handler(Events.EPOCH_COMPLETED, run_evaluation)
        trainer.add_event_handler(Events.COMPLETED, run_evaluation)

        ProgressBar(persist=False,
                    desc="Train evaluation").attach(val_evaluator)

        # Log val metrics:
        tb_logger.attach(val_evaluator,
                         log_handler=OutputHandler(tag="val",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        # trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

        # Store the best model
        def default_score_fn(engine):
            score = engine.state.metrics['Accuracy']
            return score

        best_model_handler = ModelCheckpoint(dirname=log_path,
                                             filename_prefix="best",
                                             n_saved=3,
                                             score_name="val_acc",
                                             score_function=default_score_fn)
        val_evaluator.add_event_handler(Events.COMPLETED, best_model_handler,
                                        {
                                            'model': model,
                                        })

        # Add early stopping
        es_handler = EarlyStopping(patience=es_patience,
                                   score_function=default_score_fn,
                                   trainer=trainer)
        val_evaluator.add_event_handler(Events.COMPLETED, es_handler)

        setup_logger(es_handler._logger)
        setup_logger(logging.getLogger("ignite.engine.engine.Engine"))

        def empty_cuda_cache(engine):
            torch.cuda.empty_cache()
            gc.collect()

        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)
        val_evaluator.add_event_handler(Events.COMPLETED, empty_cuda_cache)

        trainer.run(train_loader, max_epochs=epochs)
    finally:
        tb_logger.close()
Exemple #8
0
            # optimizer.zero_grad()
            # loss.backward()
            # optimizer.step()
            #
            # losses3 += loss.item()

            if local_rank == 0:
                total = i + 1
                tbar.set_description(
                    'epoch: %d, loss manual mining: %.3f, loss hard mining: %.3f'
                    % (epoch + 1, losses_manual_mining / total,
                       losses_hard_mining / total))

            # tbar.set_description('epoch: %d, loss1: %.3f, loss2: %.3f'
            #                      % (epoch + 1, losses_manual_mining / (i + 1), losses_hard_mining / (i + 1)))
        scheduler.step(epoch)
        if local_rank == 0:
            checkpoints = {
                'query': text_encoder.module.state_dict(),
                'item': image_encoder.module.state_dict(),
                'score': score_model.module.state_dict(),
                # 'category': category_embedding.state_dict(),
                # 'generator': text_generator.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(
                checkpoints,
                os.path.join(checkpoints_dir,
                             'model-epoch{}.pth'.format(epoch + 1)))
            # score_model.eval()
            # text_encoder.eval()
    def train_and_eval(self):
        """Train the selected P2V model on the co-occurrence data in ``d``.

        Each iteration shuffles the observed co-occurrence pairs, draws
        negative pairs from the word/context marginals and regresses the
        model's predictions onto PMI targets. Every 10th iteration the
        embeddings and the loss history (including the iteration just
        finished) are saved with ``np.save``.

        Raises:
            ValueError: if ``self.model_name`` is neither ``"p2v-l"`` nor
                ``"p2v-p"`` (case-insensitive).
        """
        print("Training the %s model..." % self.model_name)

        outfolder = "/afs/inf.ed.ac.uk/group/project/Knowledge_Bases/stats/"
        fname = "result_%dmil_%s_ws%d_w%d_c%d_d%d_" % (int(np.ceil(d.cutoff/1e6)),
                            self.model_name, d.window_size, int(self.w_reg*10),
                            int(self.c_reg*10), self.embeddings_dim)

        model_key = self.model_name.lower()
        if model_key == "p2v-l":
            model = P2VL(self.embeddings_dim)
        elif model_key == "p2v-p":
            model = P2VP(self.embeddings_dim)
        else:
            # fail here with a clear message instead of an unbound `model` later
            raise ValueError("Unsupported model name: %r" % self.model_name)

        if self.cuda:
            model.cuda()

        opt = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
        if self.decay_rate:
            scheduler = ExponentialLR(opt, self.decay_rate)

        data_idxs = range(len(d.word_counts))
        data_w_probs = [d.w_probs[i] for i in data_idxs]
        data_c_probs = [d.c_probs[i] for i in data_idxs]
        cooccurrences = list(d.cooccurrence_counts.keys())

        # floor used by pmi_ii for pairs that have no self co-occurrence count
        ww_cooccurrences = [d.cooccurrence_counts.get((idx, idx), 100.) for idx in d.w_probs]
        min_p = np.min(ww_cooccurrences)/1.5

        def pmi_ii(data_batch, i):
            # self-PMI of element `i` of every pair, clamped below at 0.1
            return np.maximum(0.1, np.log([d.cooccurrence_counts.get((pair[i], pair[i]), min_p)/
                                           d.w_probs[pair[i]]**2 for pair in data_batch]))

        def pmi(data_batch):
            # PMI of every pair; pairs that were never observed get -1
            targets = [d.cooccurrence_counts.get(pair, 0.)/(d.w_probs[pair[0]]*
                       d.c_probs[pair[1]]) for pair in data_batch]
            return np.array([np.log(target) if target!=0 else -1. for target in targets])

        losses = []
        tick = time.time()
        counter = 0
        neg_per_batch = self.batch_size*self.corrupt_size
        for i in range(1, self.num_iterations+1):
            model.train()
            np.random.shuffle(cooccurrences)
            num_batches = int(np.ceil(len(cooccurrences)/float(self.batch_size)))
            num_neg_samples = neg_per_batch*num_batches
            all_neg_pairs = list(zip(np.random.choice(data_idxs, num_neg_samples, p=data_w_probs),
                                     np.random.choice(data_idxs, num_neg_samples, p=data_c_probs)))

            epoch_loss = []
            for j in range(0, len(cooccurrences), self.batch_size):
                counter += 1
                pos_pairs = cooccurrences[j:j+self.batch_size]
                # j counts positive pairs, so the negatives start at j*corrupt_size
                neg_start = j*self.corrupt_size
                neg_pairs = all_neg_pairs[neg_start:neg_start+neg_per_batch]
                data_batch = pos_pairs + neg_pairs
                targets = torch.FloatTensor(pmi(data_batch))
                if model_key == "p2v-l":
                    targets_w = torch.FloatTensor(np.sqrt(pmi_ii(data_batch, 0)))
                    targets_c = torch.FloatTensor(np.sqrt(pmi_ii(data_batch, 1)))
                else:  # "p2v-p"
                    targets_w = torch.FloatTensor(pmi_ii(data_batch, 0))
                    targets_c = torch.FloatTensor(np.zeros(len(targets)))
                opt.zero_grad()
                data_batch = np.array(data_batch)
                w_idx = torch.tensor(data_batch[:,0])
                c_idx = torch.tensor(data_batch[:,1])
                if self.cuda:
                    w_idx = w_idx.cuda()
                    c_idx = c_idx.cuda()
                    targets = targets.cuda()
                    targets_w = targets_w.cuda()
                    targets_c = targets_c.cuda()
                preds, preds_w, preds_c = model.forward(w_idx, c_idx)
                loss = model.loss(preds, targets) +\
                       self.w_reg * model.loss(preds_w, targets_w) +\
                       self.c_reg * model.loss(preds_c, targets_c)
                loss.backward()
                opt.step()
                epoch_loss.append(loss.item())
                # decay the learning rate every 500 batches
                if self.decay_rate and not counter%500:
                    scheduler.step()
            print("Iteration: %d" % i)
            print("Loss: %.4f" % np.mean(epoch_loss))
            # record the loss before saving so the checkpoint includes this iteration
            losses.append(np.mean(epoch_loss))
            if not i%10:
                np_W = model.W.weight.detach().cpu().numpy()
                np_C = model.C.weight.detach().cpu().numpy()
                if i == 10:
                    # the first checkpoint also stores the corpus statistics
                    np.save("%s%s%d.npy" % (outfolder, fname, i), {"W":np_W, "C":np_C,
                            "p_w":d.w_probs, "p_c":d.c_probs, "p_wc":d.cooccurrence_counts,
                            "losses":losses, "i2w":d.idx_to_word})
                else:
                    np.save("%s%s%d.npy" % (outfolder, fname, i), {"W":np_W, "C":np_C,
                            "losses":losses})
        print("Time: ", str(time.time()-tick))
    def train_stand(self, train_data, valid_data, derived_struct, rela_cluster,
                    mrr):
        """Train the model on a fixed structure and log periodic evaluations.

        The optimizer is chosen from ``self.args.optim`` (Adam, Adagrad, or
        SGD as the fallback) and decayed once per epoch with ``ExponentialLR``.
        Every fifth epoch the validation and test testers are run and one
        result line is appended to ``self.args.perf_file``, which must already
        be set by the caller.

        Args:
            train_data: tuple ``(head, tail, rela)`` of equally long tensors.
            valid_data: unused here; kept for interface compatibility.
            derived_struct: structure forwarded to the model and the testers.
            rela_cluster: relation clustering forwarded to ``rela_to_dict``.
            mrr: unused here; kept for interface compatibility.

        Returns:
            The highest test MRR seen at a periodic evaluation, or ``0`` when
            no evaluation ran or none exceeded ``0``.
        """
        self.rela_to_dict(rela_cluster)

        head, tail, rela = train_data
        n_train = len(head)

        if self.args.optim == 'adam' or self.args.optim == 'Adam':
            self.optimizer = Adam(self.model.parameters(), lr=self.args.lr)
        elif self.args.optim == 'adagrad' or self.args.optim == 'Adagrad':
            self.optimizer = Adagrad(self.model.parameters(), lr=self.args.lr)
        else:
            self.optimizer = SGD(self.model.parameters(), lr=self.args.lr)
        scheduler = ExponentialLR(self.optimizer, self.args.decay_rate)

        n_batch = self.args.n_batch

        best_mrr = 0
        # Bound up front so the final report cannot fail when no evaluation
        # ever improves on ``best_mrr`` (e.g. fewer than five epochs).
        best_str = ''
        start = time.time()
        for epoch in range(self.args.n_stand_epoch):
            # Reshuffle the triples; all three tensors share one permutation.
            rand_idx = torch.randperm(n_train)
            head, tail, rela = head[rand_idx], tail[rand_idx], rela[rand_idx]
            if self.GPU:
                head, tail, rela = head.cuda(), tail.cuda(), rela.cuda()

            epoch_loss = 0

            # train model weights
            for h, t, r in batch_by_size(n_batch,
                                         head,
                                         tail,
                                         rela,
                                         n_sample=n_train):

                self.model.zero_grad()

                loss = self.model.forward(derived_struct, h, t, r,
                                          self.cluster_rela_dict)
                loss += self.args.lamb * self.model.regul
                loss.backward()

                self.optimizer.step()
                self.prox_operator()

                epoch_loss += loss.item()

            scheduler.step()

            # Arguments follow the order of the placeholders: loss first,
            # then the time elapsed since training started.
            print("Epoch: %d/%d, Loss=%.2f, Stand Time=%.2f" %
                  (epoch + 1, self.args.n_stand_epoch, epoch_loss / n_train,
                   time.time() - start))

            if (epoch + 1) % 5 == 0:
                test, randint = True, None

                valid_mrr, valid_mr, valid_1, valid_3, valid_10 = self.tester_val(
                    derived_struct, test, randint)
                test_mrr, test_mr, test_1, test_3, test_10 = self.tester_tst(
                    derived_struct, test, randint)

                out_str = '%d \t %.2f \t %.2f \t %.4f  %.1f %.4f %.4f %.4f\t%.4f %.1f %.4f %.4f %.4f\n' % (
                    epoch, self.time_tot, epoch_loss / n_train,
                    valid_mrr, valid_mr, valid_1, valid_3, valid_10,
                    test_mrr, test_mr, test_1, test_3, test_10)

                # Remember the line with the best test MRR so far.
                if test_mrr > best_mrr:
                    best_mrr = test_mrr
                    best_str = out_str

                with open(self.args.perf_file, 'a+') as f:
                    f.write(out_str)

        with open(self.args.perf_file, 'a+') as f:
            f.write("best performance:" + best_str + "\n")
            f.write("struct:" + str(derived_struct) + "\n")
            f.write("rela:" + str(rela_cluster) + "\n")

        return best_mrr
    def train_oas(self, train_data, valid_data, derived_struct):
        """Run the search loop: model-weight epochs interleaved with controller updates.

        Each epoch shuffles the training triples, trains the model weights
        with gradient clipping, decays the learning rate, optionally
        re-clusters the relations (``cluster_way == "scu"``), trains the
        controller and derives a new structure that is used for the next
        epoch.

        Args:
            train_data: tuple ``(head, tail, rela)`` of equally long tensors.
            valid_data: unused here; kept for interface compatibility.
            derived_struct: structure used for the first epoch.
        """
        head, tail, rela = train_data
        n_train = len(head)

        # Pick the optimizer class by name; anything unrecognised means SGD.
        optim_name = self.args.optim
        if optim_name in ('adam', 'Adam'):
            optim_cls = Adam
        elif optim_name in ('adagrad', 'Adagrad'):
            optim_cls = Adagrad
        else:
            optim_cls = SGD
        self.optimizer = optim_cls(self.model.parameters(), lr=self.args.lr)

        scheduler = ExponentialLR(self.optimizer, self.args.decay_rate)

        n_batch = self.args.n_batch

        for epoch in range(self.args.n_oas_epoch):
            start = time.time()

            # One shared permutation keeps head/tail/rela aligned.
            perm = torch.randperm(n_train)
            head, tail, rela = head[perm], tail[perm], rela[perm]
            if self.GPU:
                head, tail, rela = head.cuda(), tail.cuda(), rela.cuda()

            epoch_loss = 0

            # train model weights
            batches = batch_by_size(n_batch, head, tail, rela,
                                    n_sample=n_train)
            for h, t, r in batches:
                self.model.zero_grad()

                loss = self.model.forward(derived_struct, h, t, r,
                                          self.cluster_rela_dict)
                loss += self.args.lamb * self.model.regul
                loss.backward()

                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.args.grad_clip)
                self.optimizer.step()
                self.prox_operator()

                epoch_loss += loss.data.cpu().numpy()

            scheduler.step()

            if self.cluster_way == "scu":
                self.rela_cluster = self.cluster()
                self.rela_cluster_history.append(self.rela_cluster)
                self.rela_to_dict(self.rela_cluster)

            # train controller
            self.train_controller()

            # Stop the clock before deriving: evaluating the derived
            # architecture is not counted as search time.
            elapsed = time.time() - start
            self.time_tot += elapsed
            derived_struct, test_mrr = self.derive(sample_num=1)

            summary = (epoch + 1, self.args.n_oas_epoch, self.time_tot,
                       epoch_loss / n_train, self.derived_raward_history[-1],
                       test_mrr)
            print(
                "Epoch: %d/%d, Search Time=%.2f, Loss=%.2f, Sampled Val MRR=%.8f, Tst MRR=%.8f"
                % summary)
Exemple #12
0
class TrackerSiamFC(Tracker):
    """SiamFC tracker built from an ``AlexNetV1`` backbone and a ``SiamFC`` head.

    Besides the tracking entry points (``init``, ``update``, ``track``) the
    class owns the loss, optimizer and learning-rate scheduler used by
    ``train_step`` / ``train_over``.
    """

    def __init__(self, net_path=None, **kwargs):
        """Build the network, optionally load weights, and set up training objects.

        Args:
            net_path: optional path of a saved ``state_dict`` to load.
            **kwargs: overrides for the defaults defined in ``parse_args``.
        """
        # NOTE(review): ``net_path`` (possibly None) is forwarded as the first
        # base-class argument -- confirm this is what ``Tracker`` expects.
        super(TrackerSiamFC, self).__init__(net_path, True)

        # Merge the default hyper-parameters with any overrides.
        self.cfg = self.parse_args(**kwargs)

        # setup GPU device if available (first CUDA device, else CPU)
        self.cuda = torch.cuda.is_available()

        self.device = torch.device('cuda:0' if self.cuda else 'cpu')

        # setup model
        self.net = Net(backbone=AlexNetV1(), head=SiamFC(self.cfg.out_scale))
        ops.init_weights(self.net)

        # load checkpoint if provided; tensors are first mapped to CPU storage
        if net_path is not None:
            self.net.load_state_dict(
                torch.load(net_path,
                           map_location=lambda storage, loc: storage))
        self.net = self.net.to(self.device)

        # setup criterion
        self.criterion = BalancedLoss()

        # setup optimizer
        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=self.cfg.initial_lr,
                                   weight_decay=self.cfg.weight_decay,
                                   momentum=self.cfg.momentum)

        # setup lr scheduler: gamma is the per-epoch factor for which
        # initial_lr * gamma ** epoch_num == ultimate_lr
        gamma = np.power(self.cfg.ultimate_lr / self.cfg.initial_lr,
                         1.0 / self.cfg.epoch_num)

        self.lr_scheduler = ExponentialLR(self.optimizer, gamma)

    def parse_args(self, **kwargs):
        """Return the configuration as an immutable ``Config`` namedtuple.

        Keyword arguments override a default only when the key already exists
        in the default table; unknown keys are ignored.
        """
        # default parameters
        cfg = {
            # basic parameters
            'out_scale': 0.001,
            'exemplar_sz': 127,  # side of the exemplar (template) patch
            'instance_sz': 255,  # side of the search (instance) patch
            'context': 0.5,
            # inference parameters
            'scale_num': 3,
            'scale_step': 1.0375,
            'scale_lr': 0.59,
            'scale_penalty': 0.9745,
            'window_influence': 0.176,
            'response_sz': 17,
            'response_up': 16,
            'total_stride': 8,
            # train parameters
            'epoch_num': 50,
            'batch_size': 8,
            'num_workers': 8,  # was 32 originally; lowered to 8
            'initial_lr': 1e-2,
            'ultimate_lr': 1e-5,
            'weight_decay': 5e-4,  # L2 regularisation strength
            'momentum': 0.9,
            'r_pos': 16,
            'r_neg': 0
        }

        for key, val in kwargs.items():
            if key in cfg:
                cfg[key] = val
        # A namedtuple gives read-only, attribute-style access to the fields.
        return namedtuple('Config', cfg.keys())(**cfg)

    # Two ways to disable gradient tracking:
    #   1. decorate a function with @torch.no_grad(), so no gradients are
    #      recorded while it runs;
    #   2. wrap the relevant statements in a ``with torch.no_grad():`` block.

    @torch.no_grad()
    def init(self, img, box):
        """Initialise the tracking state from the first frame.

        Args:
            img: image array; its mean over axes (0, 1) is used as border colour.
            box: initial box, converted below from 1-indexed, left-top based
                ``[x, y, w, h]`` to 0-indexed, centre based ``[y, x, h, w]``.
        """
        # set to evaluation mode
        self.net.eval()

        # convert box to 0-indexed and center based [y, x, h, w]
        box = np.array([
            box[1] - 1 + (box[3] - 1) / 2, box[0] - 1 +
            (box[2] - 1) / 2, box[3], box[2]
        ],
                       dtype=np.float32)
        # centre (y, x) and size (h, w) of the target taken from the given box
        self.center, self.target_sz = box[:2], box[2:]

        # create hanning window; with the defaults response_up=16 and
        # response_sz=17 the upscaled response is 272 x 272
        self.upscale_sz = self.cfg.response_up * self.cfg.response_sz
        self.hann_window = np.outer(  # outer product of two 1-D windows
            np.hanning(self.upscale_sz), np.hanning(self.upscale_sz))
        # normalise so that the window sums to one
        self.hann_window /= self.hann_window.sum()

        # search scale factors: scale_step raised to scale_num evenly spaced
        # exponents in [-(scale_num // 2), scale_num // 2]
        self.scale_factors = self.cfg.scale_step**np.linspace(
            -(self.cfg.scale_num // 2),
            self.cfg.scale_num // 2,
            self.cfg.scale_num)

        # exemplar and search sizes: the context margin is
        # cfg.context * (h + w), added to both sides of the target size
        context = self.cfg.context * np.sum(self.target_sz)
        # geometric mean of the padded height and width (no scale factor yet)
        self.z_sz = np.sqrt(np.prod(self.target_sz + context))
        # the search region keeps the instance/exemplar size ratio
        self.x_sz = self.z_sz * self.cfg.instance_sz / self.cfg.exemplar_sz

        # exemplar image; the per-channel image mean is used as border value
        self.avg_color = np.mean(img, axis=(0, 1))
        z = ops.crop_and_resize(img,
                                self.center,
                                self.z_sz,
                                out_size=self.cfg.exemplar_sz,
                                border_value=self.avg_color)

        # exemplar features: HWC array -> 1 x C x H x W float tensor
        z = torch.from_numpy(z).to(self.device).permute(
            2, 0, 1).unsqueeze(0).float()

        self.kernel = self.net.backbone(z)

    @torch.no_grad()
    def update(self, img):
        """Locate the target in ``img`` and update centre and size.

        Returns:
            ``np.ndarray`` with the 1-indexed, left-top based box ``[x, y, w, h]``.
        """
        # set to evaluation mode
        self.net.eval()

        # search images: one crop per scale factor
        x = [
            ops.crop_and_resize(img,
                                self.center,
                                self.x_sz * f,
                                out_size=self.cfg.instance_sz,
                                border_value=self.avg_color)
            for f in self.scale_factors
        ]
        x = np.stack(x, axis=0)
        x = torch.from_numpy(x).to(self.device).permute(0, 3, 1, 2).float()

        # responses
        x = self.net.backbone(x)
        responses = self.net.head(self.kernel, x)
        responses = responses.squeeze(1).cpu().numpy()

        # upsample responses and penalize scale changes (every scale except
        # the middle one is multiplied by scale_penalty)
        responses = np.stack([
            cv2.resize(u, (self.upscale_sz, self.upscale_sz),
                       interpolation=cv2.INTER_CUBIC) for u in responses
        ])
        responses[:self.cfg.scale_num // 2] *= self.cfg.scale_penalty
        responses[self.cfg.scale_num // 2 + 1:] *= self.cfg.scale_penalty

        # peak scale
        scale_id = np.argmax(np.amax(responses, axis=(1, 2)))

        # peak location: normalise the response, then blend with the window
        response = responses[scale_id]
        response -= response.min()
        response /= response.sum() + 1e-16
        response = (1 - self.cfg.window_influence) * response + \
            self.cfg.window_influence * self.hann_window
        loc = np.unravel_index(response.argmax(), response.shape)

        # locate target center: response -> instance patch -> image
        disp_in_response = np.array(loc) - (self.upscale_sz - 1) / 2
        disp_in_instance = disp_in_response * \
            self.cfg.total_stride / self.cfg.response_up
        disp_in_image = disp_in_instance * self.x_sz * \
            self.scale_factors[scale_id] / self.cfg.instance_sz
        self.center += disp_in_image

        # update target size with a damped version of the winning scale
        scale = (1 - self.cfg.scale_lr
                 ) * 1.0 + self.cfg.scale_lr * self.scale_factors[scale_id]
        self.target_sz *= scale
        self.z_sz *= scale
        self.x_sz *= scale

        # return 1-indexed and left-top based bounding box  [x,y,w,h]
        box = np.array([
            self.center[1] + 1 - (self.target_sz[1] - 1) / 2,
            self.center[0] + 1 - (self.target_sz[0] - 1) / 2,
            self.target_sz[1], self.target_sz[0]
        ])

        return box

    def track(self, img_files, box, visualize=False):
        """Track through ``img_files`` starting from ``box`` (``[x, y, w, h]``).

        Returns:
            ``(boxes, times)``: a ``(frame_num, 4)`` array of boxes and the
            per-frame processing time in seconds (image reading excluded).
        """
        frame_num = len(img_files)
        boxes = np.zeros((frame_num, 4))
        boxes[0] = box
        times = np.zeros(frame_num)

        for f, img_file in enumerate(img_files):
            img = ops.read_image(img_file)
            begin = time.time()
            if f == 0:
                self.init(img, box)
            else:
                boxes[f, :] = self.update(img)
            times[f] = time.time() - begin

            if visualize:
                ops.show_image(img, boxes[f, :])

        return boxes, times

    def train_step(self, batch, backward=True):
        """Run one forward pass on ``batch`` and, if ``backward``, one optimizer step.

        Args:
            batch: sequence whose first two items are the exemplar and the
                instance tensors.
            backward: when False the network is put in eval mode and no
                gradients are computed.

        Returns:
            The loss of this batch as a Python float.
        """
        # set network mode (train when backward is True, eval otherwise)
        self.net.train(backward)

        # parse batch data
        z = batch[0].to(self.device, non_blocking=self.cuda)
        x = batch[1].to(self.device, non_blocking=self.cuda)

        with torch.set_grad_enabled(backward):
            # inference
            responses = self.net(z, x)

            # calculate loss
            labels = self._create_labels(responses.size())
            loss = self.criterion(responses, labels)

            if backward:
                # back propagation
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        return loss.item()

    # enable_grad lets this method compute gradients even when it is called
    # from inside a no-grad context
    @torch.enable_grad()
    def train_over(self, seqs, val_seqs=None, save_dir='pretrained'):
        """Train for ``cfg.epoch_num`` epochs and save a checkpoint per epoch.

        Args:
            seqs: training sequences handed to ``Pair``.
            val_seqs: currently unused.
            save_dir: directory receiving ``siamfc_alexnet_e<epoch>.pth`` files;
                created if missing.
        """
        # set to train mode
        self.net.train()
        # create save_dir folder
        os.makedirs(save_dir, exist_ok=True)

        # setup dataset
        transforms = SiamFCTransforms(
            exemplar_sz=self.cfg.exemplar_sz,
            instance_sz=self.cfg.instance_sz,
            context=self.cfg.context)

        dataset = Pair(seqs=seqs, transforms=transforms)

        # setup dataloader
        dataloader = DataLoader(dataset,
                                batch_size=self.cfg.batch_size,
                                shuffle=True,
                                num_workers=self.cfg.num_workers,
                                pin_memory=self.cuda,
                                drop_last=True)

        # loop over epochs
        for epoch in range(self.cfg.epoch_num):
            # loop over dataloader
            for it, batch in enumerate(dataloader):

                loss = self.train_step(batch, backward=True)

                print('Epoch: {} [{}/{}] Loss: {:.5f}'.format(
                    epoch + 1, it + 1, len(dataloader), loss))

                sys.stdout.flush()

            # Decay the learning rate once the epoch's optimizer steps are
            # done.  Stepping the scheduler before optimizer.step() and
            # passing ``epoch=`` are both deprecated in PyTorch; this order
            # gives epoch k the rate initial_lr * gamma ** k.
            self.lr_scheduler.step()

            # save checkpoint (re-create the folder in case it was removed)
            os.makedirs(save_dir, exist_ok=True)

            net_path = os.path.join(save_dir,
                                    'siamfc_alexnet_e%d.pth' % (epoch + 1))

            torch.save(self.net.state_dict(), net_path)

    def _create_labels(self, size):
        """Return (and cache) the label tensor for responses of shape ``size``.

        Args:
            size: ``(n, c, h, w)`` shape of the response tensor.
        """
        # skip if same sized labels already created
        if hasattr(self, 'labels') and self.labels.size() == size:

            return self.labels

        def logistic_labels(x, y, r_pos, r_neg):
            # 1 within r_pos, 0.5 strictly inside r_neg, 0 elsewhere
            dist = np.abs(x) + np.abs(y)  # block distance

            labels = np.where(
                dist <= r_pos, np.ones_like(x),
                np.where(dist < r_neg,
                         np.ones_like(x) * 0.5, np.zeros_like(x)))

            return labels

        # distances along x- and y-axis, measured from the map centre
        n, c, h, w = size
        x = np.arange(w) - (w - 1) / 2
        y = np.arange(h) - (h - 1) / 2
        x, y = np.meshgrid(x, y)

        # create logistic labels; radii are converted to response-map cells
        r_pos = self.cfg.r_pos / self.cfg.total_stride
        r_neg = self.cfg.r_neg / self.cfg.total_stride
        labels = logistic_labels(x, y, r_pos, r_neg)

        # repeat to size
        labels = labels.reshape((1, 1, h, w))
        labels = np.tile(labels, (n, c, 1, 1))

        # convert to tensors
        self.labels = torch.from_numpy(labels).to(self.device).float()

        return self.labels
Exemple #13
0
class Trainer_new(Trainer):
    """Trainer for the knowledge-graph objective with periodic evaluation and early stopping."""

    def __init__(self, args, logger=None):
        """Store the run configuration, load the data and any pretrained state.

        Args:
            args: configuration namespace (``lr``, ``decay_rate``,
                ``use_cuda`` and further fields read by the other methods).
            logger: optional logger; when omitted, messages that would go to
                it are printed instead.
        """
        self.args = args
        self.logger = logger
        self.best_dev_performance = 0.0
        self.device = torch.device('cuda' if args.use_cuda else 'cpu')
        self.learning_rate = self.args.lr
        self.decay_rate = self.args.decay_rate
        self.reset_time = 0
        self.data_load_kg_rs()
        self.load_pretrain(args)

    def _log_info(self, message):
        """Send ``message`` to the logger, or to stdout when no logger was given."""
        if self.logger is not None:
            self.logger.info(message)
        else:
            print(message)

    def optim_def(self, lr):
        """Create the Adam optimizer (and the decay scheduler when enabled).

        The scheduler only exists when ``decay_rate`` is truthy, which is the
        same condition ``train_epoch_kg`` checks before stepping it.
        """
        self.learning_rate = lr
        self.optim = optim.Adam(self.model.parameters(), lr=lr)
        if self.decay_rate:  # decay_rate > 0
            self.scheduler = ExponentialLR(self.optim, self.decay_rate)
        self.reset_time += 1

    def model_def(self):
        """Derive the vocabulary sizes from the id maps and build the model."""
        self.entity_total = len(self.e_map)
        self.relation_total = len(self.r_map)
        self.user_total = len(self.u_map)
        self.item_total = len(self.i_map)
        self.share_total = len(self.i_kg_map)
        self.model = init_model(self.args,
                                self.user_total,
                                self.item_total,
                                self.entity_total,
                                self.relation_total,
                                self.logger,
                                None,
                                None,
                                None,
                                share_total=self.share_total)

    def loss_def(self):
        """Create the element-wise BCE-with-logits loss and the L2 regulariser."""
        # reduction="none" keeps per-element losses so they can be weighted
        self.loss_func_kg = torch.nn.BCEWithLogitsLoss(reduction="none")
        self.loss_reg = Regularization(self.model,
                                       weight_decay=self.args.l2_lambda,
                                       p=2)

    # @override
    def save_ckpt(self):
        """Save the model weights and best dev score to ``<checkpoint_dir>/<experiment_name>.ckpt``."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'best_dev_performance': self.best_dev_performance
        }
        model_name = os.path.join(self.args.checkpoint_dir,
                                  "{}.ckpt".format(self.args.experiment_name))
        torch.save(checkpoint, model_name)
        print("Save model as %s" % model_name)

    def train(self, start_epoch, end_epoch):
        """Train from ``start_epoch`` to ``end_epoch`` (inclusive) with early stopping.

        Every ``args.eval_every`` epochs the dev loader is evaluated; an
        improved filtered MRR saves a checkpoint and resets the patience
        counter, otherwise the counter grows and training stops once it
        reaches 5.  ``evaluate_best`` is called at the end.
        """
        eval_every = self.args.eval_every
        print("Strat Training------------------")
        self.show_norm()
        if self.args.norm_one:
            self.model.norm_one()
        elif self.args.norm_emb:
            self.model.norm_emb()
        if self.args.norm_user:
            self.model.norm_user()
        for epoch in range(start_epoch, end_epoch + 1):
            st = time.time()
            train_loss, pos_loss, neg_loss = self.train_epoch_kg()
            self.show_norm()
            print("Epoch: {}, loss: {:.4f}, pos: {:.4f},"
                  " neg: {:.4f}, time: {}".format(epoch + 1, train_loss,
                                                  pos_loss, neg_loss,
                                                  time.time() - st))
            if (epoch + 1) % eval_every == 0 and epoch > 0:
                mr_raw, hits_10_raw, mrr_raw,\
                mr_fil, hits_10_fil, mrr_fil = self.evaluator_kg.evaluate(self.kg_eval_loader, "EVAL")
                if mrr_fil > self.best_dev_performance:
                    self.best_dev_performance = mrr_fil
                    self.save_ckpt()
                    self.reset_time = 0
                else:
                    self._log_info(
                        'No improvement after one evaluation iter.')
                    self.reset_time += 1
                # NOTE(review): reset_time is also incremented by optim_def,
                # so the patience counter may start above zero -- confirm.
                if self.reset_time >= 5:
                    self._log_info(
                        'No improvement after 5 evaluation. Early Stopping.')
                    break
        self._log_info('Train Done! Evaluate on testset with saved model')
        print("End Training------------------")
        self.evaluate_best()

    def dis_step(self, heads, rels, tails, neg_tails, sample_weight, user_rep):
        """Run one optimisation step on a batch and return its losses.

        Positives are trained towards ``1 - label_smoothing_epsilon`` and
        negatives towards ``label_smoothing_epsilon / n_sample``; both parts
        are weighted per sample and summed.

        Returns:
            ``(loss, loss_pos, loss_neg)`` as Python floats.
        """
        self.optim.zero_grad()
        query_now, _ = self.model.form_query(heads, rels, user_rep)
        preds, _ = self.model.query_judge(query_now, tails)
        preds_neg, _ = self.model.query_judge(query_now, neg_tails)

        random_ratio = self.args.label_smoothing_epsilon / self.args.n_sample
        answers_true = torch.ones_like(preds) * (
            1.0 - self.args.label_smoothing_epsilon)
        answers_false = torch.zeros_like(preds_neg) + random_ratio

        loss_pos = self.loss_func_kg(preds, answers_true)
        loss_pos = (loss_pos * sample_weight).sum()

        # sum over the negatives of each sample before weighting
        loss_neg = torch.sum(self.loss_func_kg(preds_neg, answers_false),
                             dim=1)
        loss_neg = (loss_neg * sample_weight).sum()

        loss = loss_pos + loss_neg

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                       self.args.clipping_max_value)
        self.optim.step()
        # re-apply the configured normalisation after the parameter update
        if self.args.norm_one:
            self.model.norm_one()
        elif self.args.norm_emb:
            self.model.norm_emb()
        if self.args.norm_user:
            self.model.norm_user()
        return loss.item(), loss_pos.item(), loss_neg.item()

    def train_epoch_kg(self):
        """Train one epoch over ``kg_train_loader``.

        Returns:
            ``(mean_loss, mean_pos_loss, mean_neg_loss)`` over the epoch's batches.
        """
        path_list = self.sample_epoch_edges()
        self.model.train()
        losses = []
        losses_pos = []
        losses_neg = []
        path_reverse = self.sample_epoch_reverse()
        rgcn_kernel = self.get_rgcn_kenel(path_reserve=path_reverse,
                                          user_set=set(range(self.user_total)))
        for heads, rels, tails, neg_tails, sample_weight in self.kg_train_loader:
            self.optim.zero_grad()
            max_hop, ent_hop, hop_map, unreach_ents, new_order = self.batch_reserve_ents(
                heads, path_list)
            rs_graph, kg_graph = make_kernel_batch(path_list, max_hop, ent_hop,
                                                   hop_map)
            batch_order = torch.LongTensor(new_order).to(self.device)
            user_rep = self.model.fetch_user_batch(rgcn_kernel=rgcn_kernel,
                                                   rs_graph=rs_graph,
                                                   kg_graph=kg_graph,
                                                   tp_hop=max_hop,
                                                   unreach_ids=unreach_ents,
                                                   batch_order=batch_order,
                                                   pretrain=False)
            heads = torch.LongTensor(heads).to(self.device)
            rels = torch.LongTensor(rels).to(self.device)
            tails = torch.LongTensor(tails).to(self.device)
            neg_tails = torch.LongTensor(neg_tails).to(self.device)
            sample_weight = sample_weight.float().to(self.device).squeeze()

            loss, loss_pos, loss_neg = self.dis_step(heads, rels, tails,
                                                     neg_tails, sample_weight,
                                                     user_rep)
            losses.append(loss)
            losses_pos.append(loss_pos)
            losses_neg.append(loss_neg)
        # the scheduler exists only when decay is enabled (see optim_def)
        if self.decay_rate:
            self.scheduler.step()
        mean_losses = np.mean(losses)
        pos_loss = np.mean(losses_pos)
        neg_loss = np.mean(losses_neg)
        return mean_losses, pos_loss, neg_loss
Exemple #14
0
class TrackerSiamFC(Tracker):
    def __init__(self, net_path=None, **kwargs):
        """Build the SiamFC network, loss, optimizer and LR schedule.

        Args:
            net_path: Optional path of a saved state dict. When given, it is
                loaded into the network before the network is moved to
                ``self.device``.
            **kwargs: Configuration overrides, forwarded to ``parse_args``.
        """
        super(TrackerSiamFC, self).__init__('SiamFC', True)
        self.cfg = self.parse_args(**kwargs)
        self.update_id = 2

        # Use the first CUDA device when one is available, else the CPU.
        self.cuda = torch.cuda.is_available()
        device_name = 'cuda:0' if self.cuda else 'cpu'
        self.device = torch.device(device_name)

        # Set up the model: AlexNetV1 backbone followed by the SiamFC head.
        backbone = AlexNetV1()
        head = SiamFC(self.cfg.out_scale)
        self.net = Net(backbone=backbone, head=head)
        ops.init_weights(self.net)

        # Optionally restore weights; the identity map_location hands each
        # storage back unchanged instead of remapping it to another device.
        if net_path is not None:
            state_dict = torch.load(
                net_path, map_location=lambda storage, loc: storage)
            self.net.load_state_dict(state_dict)
        self.net = self.net.to(self.device)

        # Training criterion.
        self.criterion = BalancedLoss()

        # SGD optimizer configured from the hyper-parameters in ``cfg``.
        self.optimizer = optim.SGD(
            self.net.parameters(),
            lr=self.cfg.initial_lr,
            weight_decay=self.cfg.weight_decay,
            momentum=self.cfg.momentum,
        )

        # Per-epoch decay factor chosen so that ``epoch_num`` decays take the
        # learning rate from ``initial_lr`` to ``ultimate_lr``.
        lr_ratio = self.cfg.ultimate_lr / self.cfg.initial_lr
        gamma = np.power(lr_ratio, 1.0 / self.cfg.epoch_num)
        self.lr_scheduler = ExponentialLR(self.optimizer, gamma)

    def parse_args(self, **kwargs):
        """Return the tracker configuration as a ``Config`` namedtuple.

        Keyword arguments whose name matches a known option replace that
        option's default; unrecognised names are ignored.
        """
        options = dict(
            # basic parameters
            out_scale=0.001,
            exemplar_sz=127,
            instance_sz=255,
            context=0.5,
            # inference parameters
            scale_num=9,
            scale_num_res=3,
            scale_step=1.0375,
            scale_lr=0.59,
            scale_penalty=0.9745,
            window_influence=0.176,
            response_sz=17,  # response map is 17 x 17
            response_up=16,
            total_stride=8,
            # train parameters
            epoch_num=50,
            batch_size=8,
            num_workers=32,
            initial_lr=1e-2,
            ultimate_lr=1e-5,
            weight_decay=5e-4,
            momentum=0.9,
            r_pos=16,
            r_neg=0,
        )

        # Only known option names may be overridden by the caller.
        overrides = {
            name: value for name, value in kwargs.items() if name in options
        }
        options.update(overrides)

        config_type = namedtuple('Config', options.keys())
        return config_type(**options)

    @torch.no_grad()
    def init(self, img, box):
        """Initialise the tracker state from the first frame.

        Args:
            img: Image array of the first frame; its mean over the first two
                axes is used as the border colour for crops.
            box: Target box, indexed as ``[x, y, w, h]`` (1-indexed,
                left-top based); converted below to a 0-indexed,
                centre-based ``[y, x, h, w]`` layout.
        """
        # set to evaluation mode
        self.net.eval()

        # convert box to 0-indexed and center based [y, x, h, w]
        center_y = box[1] - 1 + (box[3] - 1) / 2
        center_x = box[0] - 1 + (box[2] - 1) / 2
        box = np.array([center_y, center_x, box[3], box[2]],
                       dtype=np.float32)
        self.center, self.target_sz = box[:2], box[2:]

        # Normalised 2-D Hanning window sized like the upsampled response.
        # np.hanning: https://numpy.org/doc/stable/reference/generated/numpy.hanning.html
        self.upscale_sz = self.cfg.response_up * self.cfg.response_sz
        hann_1d = np.hanning(self.upscale_sz)
        self.hann_window = np.outer(hann_1d, hann_1d)
        self.hann_window /= self.hann_window.sum()

        # Scale factors indexed by ``update`` when it rescales the target.
        span = self.cfg.scale_num // 8
        self.scale_factors = self.cfg.scale_step**np.linspace(
            -span, span, self.cfg.scale_num)

        # Residual scale factors used by ``update`` to crop search images.
        span_res = self.cfg.scale_num_res // 2
        self.scale_factors_res = self.cfg.scale_step**np.linspace(
            -span_res, span_res, self.cfg.scale_num_res)

        # exemplar and search sizes
        context = self.cfg.context * np.sum(self.target_sz)
        self.z_sz = np.sqrt(np.prod(self.target_sz + context))
        self.x_sz = self.z_sz * self.cfg.instance_sz / self.cfg.exemplar_sz

        # exemplar image, padded with the frame's mean colour
        self.avg_color = np.mean(img, axis=(0, 1))
        exemplar = ops.crop_and_resize(img,
                                       self.center,
                                       self.z_sz,
                                       out_size=self.cfg.exemplar_sz,
                                       border_value=self.avg_color)

        # exemplar features: channel-first tensor with a batch axis of one,
        # passed through the backbone (AlexNetV1)
        exemplar = torch.from_numpy(exemplar).to(self.device)
        exemplar = exemplar.permute(2, 0, 1).unsqueeze(0).float()
        self.kernel = self.net.backbone(exemplar)

    @torch.no_grad()
    def update(self, count, img_files):
        """Search the candidate images for the target and update the state.

        Args:
            count: Frame index. ``1`` crops every image in ``img_files`` once
                at the current search size; any other value crops every image
                once per factor in ``self.scale_factors_res``.
            img_files: Paths of the candidate images. The file name is
                expected to look like ``<number>.<ext>``.

        Returns:
            A tuple ``(box, responses, search, png_id)``:
            ``box`` is the 1-indexed, left-top based ``[x, y, w, h]`` box,
            ``responses`` the upsampled response maps (the winning map has
            been normalised in place), ``search`` the list of search crops and
            ``png_id`` the integer parsed from the winning file's name.
        """
        # set to evaluation mode
        self.net.eval()

        # Frame 1 uses a single crop per image; later frames use one crop per
        # residual scale factor, image by image.
        crop_factors = (1.0,) if count == 1 else self.scale_factors_res
        x = []
        for img_file in img_files:
            img = ops.read_image(img_file)
            x.extend(
                ops.crop_and_resize(img,
                                    self.center,
                                    self.x_sz * f,
                                    out_size=self.cfg.instance_sz,
                                    border_value=self.avg_color)
                for f in crop_factors)

        # search images
        search = x
        x = np.stack(x, axis=0)
        x = torch.from_numpy(x).to(self.device).permute(0, 3, 1, 2).float()

        # responses
        x = self.net.backbone(x)  # AlexNetV1()
        responses = self.net.head(self.kernel, x)
        responses = responses.squeeze(1).cpu().numpy()

        scale_factors = self.scale_factors
        # Index of the only response that is NOT damped by scale_penalty.
        unpenalized = self.cfg.scale_num // 8

        # upsample responses and penalize scale changes
        responses = np.stack([
            cv2.resize(u, (self.upscale_sz, self.upscale_sz),
                       interpolation=cv2.INTER_CUBIC) for u in responses
        ])
        responses[:unpenalized] *= self.cfg.scale_penalty
        responses[unpenalized + 1:] *= self.cfg.scale_penalty

        # peak scale: index of the response map with the highest maximum
        scale_id = np.argmax(np.amax(responses, axis=(1, 2)))

        # Map the winning response back to the image it was cropped from
        # (three crops per image when count != 1).
        if count == 1:
            file_id = scale_id
        elif scale_id < 3:
            file_id = 0
        elif scale_id < 6:
            file_id = 1
        else:
            file_id = 2

        # Accept both "\\" and "/" separators so the number is parsed from
        # the bare file name regardless of the platform that built the path.
        file_name = img_files[file_id].replace('\\', '/').rsplit('/', 1)[-1]
        png_id = int(file_name.split('.')[-2])

        # peak location; ``response`` is a view, so the in-place
        # normalisation is also visible in the returned ``responses``
        response = responses[scale_id]
        response -= response.min()
        response /= response.sum() + 1e-16
        response = (1 - self.cfg.window_influence
                    ) * response + self.cfg.window_influence * self.hann_window
        loc = np.unravel_index(response.argmax(), response.shape)

        # locate target center
        # NOTE(review): scale_factors (length scale_num) is indexed by the
        # response index here, not scale_factors_res -- confirm intended.
        disp_in_response = np.array(loc) - (self.upscale_sz - 1) / 2
        disp_in_instance = disp_in_response * self.cfg.total_stride / self.cfg.response_up
        disp_in_image = disp_in_instance * self.x_sz * scale_factors[
            scale_id] / self.cfg.instance_sz
        self.center += disp_in_image

        # Blend the selected scale factor with 1.0 using scale_lr.
        scale = (1 - self.cfg.scale_lr
                 ) * 1.0 + self.cfg.scale_lr * scale_factors[scale_id]
        self.target_sz *= scale
        self.z_sz *= scale
        self.x_sz *= scale

        # return 1-indexed and left-top based bounding box
        box = np.array([
            self.center[1] + 1 - (self.target_sz[1] - 1) / 2,
            self.center[0] + 1 - (self.target_sz[0] - 1) / 2,
            self.target_sz[1], self.target_sz[0]
        ])

        return box, responses, search, png_id

    def track(self, first_img, focal_dirs, show_imgs, box, visualize=False):
        """Track the target through a sequence of focal-stack directories.

        Args:
            first_img: Image used to initialise the tracker at frame 0.
            focal_dirs: One directory of focal-plane images per frame.
            show_imgs: Per-frame image paths used for the annotated output.
            box: Initial target box passed to ``init``.
            visualize: Currently ignored; annotated frames are written
                whenever a non-zero response is available.

        Returns:
            ``(boxes, times, response)``: the per-frame boxes, the per-frame
            processing times in seconds and the last response maps.

        Side effects:
            Writes one ``x,y,w,h`` line per visualised frame to
            ``recode/bbox.txt`` and one annotated image per visualised frame
            to ``./image/NNN.png``.
        """
        frame_num = len(focal_dirs)
        boxes = np.zeros((frame_num, 4))
        boxes[0] = box
        times = np.zeros(frame_num)
        # All-zero placeholder: marks "no response yet" for frame 0.
        response = np.zeros((3, 272, 272))

        with open("recode/bbox.txt", 'w') as gt:
            for f, focal_dir in enumerate(focal_dirs):
                # glob returns entries in arbitrary order; sort them so that
                # list positions follow the file names (presumably
                # zero-padded plane numbers -- TODO confirm).
                focal_plane = sorted(glob.glob(focal_dir + "/*"))

                begin = time.time()
                if f == 0:
                    self.init(first_img, box)
                elif f == 1:
                    # sharpfiles() selects the candidates for the first
                    # update (original note: "the first is 027.png").
                    sharp = sharpfiles(focal_dir)
                    boxes[f, :], response, _, scale_id = self.update(
                        f, sharp)
                else:
                    # Search the previously selected plane and its two
                    # neighbours; clamp the start so plane 0 does not turn
                    # into a negative (wrap-around) slice index.
                    first = max(scale_id - 1, 0)
                    focal_planes = focal_plane[first:scale_id + 2]
                    boxes[f, :], response, _, scale_id = self.update(
                        f, focal_planes)  # list of focal images
                times[f] = time.time() - begin

                # Skip output while the response is still the placeholder.
                if response[0][0][0] == 0.0:
                    continue

                dimg = ops.read_image(show_imgs[f])
                save_img, centerloc = ops.show_image(dimg,
                                                     boxes[f, :],
                                                     None,
                                                     num_name=f)
                save_img = cv2.putText(save_img, str(scale_id), (100, 30),
                                       cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                       (0, 255, 0), 2)
                savepath = "./image/{0:0=3d}.png".format(f)
                cv2.imwrite(savepath, save_img)
                # record the drawn box for this frame
                center_coo = "%d,%d,%d,%d\n" % (int(
                    centerloc[0]), int(centerloc[1]), int(
                        centerloc[2]), int(centerloc[3]))
                gt.write(center_coo)

        return boxes, times, response

    def train_step(self, batch, backward=True):
        """Run the network on one batch and return the loss as a float.

        Args:
            batch: Indexable whose first two items are the exemplar and
                search tensors.
            backward: When true the network is put in train mode and one
                optimizer step is taken; otherwise it is put in eval mode and
                gradients are disabled.
        """
        # ``train(False)`` is equivalent to ``eval()``.
        self.net.train(backward)

        exemplar = batch[0].to(self.device, non_blocking=self.cuda)
        instance = batch[1].to(self.device, non_blocking=self.cuda)

        with torch.set_grad_enabled(backward):
            # forward pass and loss against the (cached) label map
            responses = self.net(exemplar, instance)
            labels = self._create_labels(responses.size())
            loss = self.criterion(responses, labels)

            if backward:
                # back propagation and parameter update
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        return loss.item()

    @torch.enable_grad()
    def train_over(self, seqs, val_seqs=None, save_dir='pretrained'):
        """Train the network on ``seqs`` and save a checkpoint every epoch.

        Args:
            seqs: Sequences handed to the ``Pair`` dataset.
            val_seqs: Accepted for interface compatibility; not used.
            save_dir: Directory that receives
                ``siamfc_alexnet_e<epoch>.pth`` after each epoch.
        """
        # set to train mode
        self.net.train()

        # create save_dir folder (no error if it already exists)
        os.makedirs(save_dir, exist_ok=True)

        # setup dataset
        transforms = SiamFCTransforms(exemplar_sz=self.cfg.exemplar_sz,
                                      instance_sz=self.cfg.instance_sz,
                                      context=self.cfg.context)
        dataset = Pair(seqs=seqs, transforms=transforms)

        # setup dataloader
        dataloader = DataLoader(dataset,
                                batch_size=self.cfg.batch_size,
                                shuffle=True,
                                num_workers=self.cfg.num_workers,
                                pin_memory=self.cuda,
                                drop_last=True)
        num_batches = len(dataloader)

        # loop over epochs
        for epoch in range(self.cfg.epoch_num):
            # loop over dataloader
            for it, batch in enumerate(dataloader):
                loss = self.train_step(batch, backward=True)
                print('Epoch: {} [{}/{}] Loss: {:.5f}'.format(
                    epoch + 1, it + 1, num_batches, loss))
                sys.stdout.flush()

            # save checkpoint (recreate the folder if it vanished meanwhile)
            os.makedirs(save_dir, exist_ok=True)
            net_path = os.path.join(save_dir,
                                    'siamfc_alexnet_e%d.pth' % (epoch + 1))
            torch.save(self.net.state_dict(), net_path)

            # Decay the learning rate once per finished epoch. Stepping after
            # the optimizer updates (and without the deprecated ``epoch``
            # argument) keeps epoch k at initial_lr * gamma**k.
            self.lr_scheduler.step()

    def _create_labels(self, size):
        """Return the label tensor for responses of shape ``size``.

        The result is cached on ``self.labels`` and reused while the
        requested size stays the same. Labels are 1 where the block (L1)
        distance to the map centre is at most ``r_pos / total_stride``, 0.5
        where it is below ``r_neg / total_stride``, and 0 elsewhere.
        """
        # Reuse the cached labels when they already have the right size.
        if hasattr(self, 'labels'):
            if self.labels.size() == size:
                return self.labels

        n, c, h, w = size

        # Offsets of every cell from the centre of the response map.
        col_offsets = np.arange(w) - (w - 1) / 2
        row_offsets = np.arange(h) - (h - 1) / 2
        grid_x, grid_y = np.meshgrid(col_offsets, row_offsets)

        # Radii expressed in response-map cells.
        r_pos = self.cfg.r_pos / self.cfg.total_stride
        r_neg = self.cfg.r_neg / self.cfg.total_stride

        # Block distance decides between positive, neutral and negative.
        dist = np.abs(grid_x) + np.abs(grid_y)
        outer = np.where(dist < r_neg,
                         np.ones_like(grid_x) * 0.5,
                         np.zeros_like(grid_x))
        labels = np.where(dist <= r_pos, np.ones_like(grid_x), outer)

        # Broadcast the single map to every sample and channel.
        labels = np.tile(labels.reshape((1, 1, h, w)), (n, c, 1, 1))

        # convert to tensors
        self.labels = torch.from_numpy(labels).to(self.device).float()
        return self.labels
Exemple #15
0
class PPO:
    def __init__(
        self,
        seed: int,
        gamma: float,
        tau: float,
        clip_param: float,
        vf_c: float,
        ent_c: float,
        input_shape: int,
        hidden_units_value: list,
        hidden_units_actor: list,
        batch_size: int,
        lr: float,
        activation: str,
        optimizer_name: str,
        batch_norm_input: bool,
        batch_norm_value_out: bool,
        action_space,
        policy_type: str,
        init_pol_std: float,
        min_pol_std: float,
        beta_1: float = 0.9,
        beta_2: float = 0.999,
        eps_opt: float = 1e-07,
        lr_schedule: Optional[str] = None,
        exp_decay_rate: Optional[float] = None,
        step_size: Optional[int] = None,
        std_transform: str = "softplus",
        init_last_layers: str = "rescaled",
        rng=None,
        modelname: str = "PPO act_crt",
    ):
        """Create the actor-critic model, its optimizer and LR scheduler.

        ``optimizer_name`` selects Adam (``"adam"``) or RMSprop
        (``"rmsprop"``); any other value leaves ``self.optimizer`` unset.
        ``lr_schedule`` selects ``StepLR`` (``"step"``) or ``ExponentialLR``
        (``"exponential"``); any other value disables scheduling. ``rng``, if
        given, is used instead of a ``RandomState`` seeded with ``seed``.
        """
        # Random generator used to shuffle rollouts in ``ppo_iter``.
        self.rng = rng if rng is not None else np.random.RandomState(seed)

        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)

        # Hyper-parameters kept on the instance.
        self.gamma = gamma
        self.tau = tau
        self.clip_param = clip_param
        self.vf_c = vf_c
        self.ent_c = ent_c
        self.batch_size = batch_size
        self.beta_1 = beta_1
        self.eps_opt = eps_opt
        self.action_space = action_space
        self.policy_type = policy_type
        self.num_actions = self.action_space.get_n_actions(policy_type=self.policy_type)
        self.batch_norm_input = batch_norm_input

        # Rollout buffer: one list per recorded quantity.
        buffer_keys = (
            "state",
            "action",
            "reward",
            "log_prob",
            "value",
            "returns",
            "advantage",
        )
        self.experience = {key: [] for key in buffer_keys}

        self.model = PPOActorCritic(
            seed,
            input_shape,
            activation,
            hidden_units_value,
            hidden_units_actor,
            self.num_actions,
            batch_norm_input,
            batch_norm_value_out,
            self.policy_type,
            init_pol_std,
            min_pol_std,
            std_transform,
            init_last_layers,
            modelname,
        )
        self.model.to(self.device)

        self.optimizer_name = optimizer_name
        if optimizer_name == "adam":
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=lr,
                betas=(beta_1, beta_2),
                eps=eps_opt,
                weight_decay=0,
                amsgrad=False,
            )
        elif optimizer_name == "rmsprop":
            # RMSprop reuses beta_1 as its smoothing constant ``alpha``.
            self.optimizer = optim.RMSprop(
                self.model.parameters(),
                lr=lr,
                alpha=beta_1,
                eps=eps_opt,
                weight_decay=0,
                momentum=0,
                centered=False,
            )

        if lr_schedule == "step":
            self.scheduler = StepLR(
                optimizer=self.optimizer, step_size=step_size, gamma=exp_decay_rate
            )
        elif lr_schedule == "exponential":
            self.scheduler = ExponentialLR(
                optimizer=self.optimizer, gamma=exp_decay_rate
            )
        else:
            self.scheduler = None

        # Diagnostic histories written out by ``save_diagnostics``.
        if self.policy_type in ('continuous', 'discrete'):
            if self.policy_type == 'continuous':
                self.std_hist = []
            else:
                self.logits_hist = []
            self.entropy_hist = []

    def train(self, state, action, old_log_probs, return_, advantage):
        advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-5)

        self.model.train()
        dist, value = self.model(state)
        entropy = dist.entropy().mean()
        if self.policy_type == 'continuous':
            new_log_probs = dist.log_prob(action)
            # self.std_hist.append(self.model.log_std.exp().detach().cpu().numpy().ravel())
            # self.entropy_hist.append(entropy.detach().cpu().numpy().ravel())
        elif self.policy_type == 'discrete':
            new_log_probs = dist.log_prob(action.reshape(-1)).reshape(-1,1)
            # self.logits_hist.append(dist.logits.detach().cpu().numpy())
            # self.entropy_hist.append(entropy.detach().cpu().numpy().ravel())
        
        
        ratio = (new_log_probs - old_log_probs).exp()  # log properties
        surr1 = ratio * advantage
        surr2 = (
            torch.clamp(ratio, 1.0 / (1 + self.clip_param), 1.0 + self.clip_param)
            * advantage
        )

        actor_loss = -torch.min(surr1, surr2).mean()
        critic_loss = (return_ - value).pow(2).mean()

        # the loss is negated in order to be maximized
        self.loss = self.vf_c * critic_loss + actor_loss - self.ent_c * entropy

        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()

        if self.scheduler:
            self.scheduler.step()


    def act(self, states):
        # useful when the states are single dimensional
        self.model.eval()
        # make 1D tensor to 2D
        with torch.no_grad():
            states = torch.from_numpy(states).float().unsqueeze(0)
            states = states.to(self.device)
            return self.model(states)

    def compute_gae(self, next_value, recompute_value=False):

        if recompute_value:
            self.model.eval()
            with torch.no_grad():
                _, values = self.model(
                    torch.Tensor(self.experience["state"]).to(self.device)
                )
            self.experience["value"] = [
                np.array(v, dtype=float) for v in values.detach().cpu().tolist()
            ]
            # for i in range(len(self.experience["value"])):
            #     _, value = self.act(self.experience["state"][i])
            #     self.experience["value"][i] = value.detach().cpu().numpy().ravel()

        rewards = self.experience["reward"]
        values = self.experience["value"]

        values = values + [next_value]
        gae = 0
        returns = []
        for step in reversed(range(len(rewards))):
            delta = rewards[step] + self.gamma * values[step + 1] - values[step]
            gae = delta + self.gamma * self.tau * gae
            returns.insert(0, gae + values[step])

        # add estimated returns and advantages to the experience
        self.experience["returns"] = returns

        advantage = [returns[i] - values[i] for i in range(len(returns))]
        self.experience["advantage"] = advantage

    # add way to reset experience after one rollout
    def add_experience(self, exp):
        for key, value in exp.items():
            self.experience[key].append(value)

    def reset_experience(self):

        self.experience = {
            "state": [],
            "action": [],
            "reward": [],
            "log_prob": [],
            "value": [],
            "returns": [],
            "advantage": [],
        }

    def ppo_iter(self):
        # pick a batch from the rollout
        states = np.asarray(self.experience["state"])
        actions = np.asarray(self.experience["action"])
        log_probs = np.asarray(self.experience["log_prob"])
        returns = np.asarray(self.experience["returns"])
        advantage = np.asarray(self.experience["advantage"])

        len_rollout = states.shape[0]
        ids = self.rng.permutation(len_rollout)
        ids = np.array_split(ids, len_rollout // self.batch_size)
        for i in range(len(ids)):

            yield (
                torch.from_numpy(states[ids[i], :]).float().to(self.device),
                torch.from_numpy(actions[ids[i], :]).float().to(self.device),
                torch.from_numpy(log_probs[ids[i], :]).float().to(self.device),
                torch.from_numpy(returns[ids[i], :]).float().to(self.device),
                torch.from_numpy(advantage[ids[i], :]).float().to(self.device),
            )

    def getBack(self, var_grad_fn):
        print(var_grad_fn)
        for n in var_grad_fn.next_functions:
            if n[0]:
                try:
                    tensor = getattr(n[0], "variable")
                    print(n[0])
                    print("Tensor with grad found:", tensor)
                    print(" - gradient:", tensor.grad)
                    print()
                except AttributeError as e:
                    self.getBack(n[0])

    def save_diagnostics(self,path):
        if self.policy_type == 'continuous':
            np.save(os.path.join(path, "std_hist"), np.array(self.std_hist))
            np.save(os.path.join(path, "entropy_hist"), np.array(self.entropy_hist))
        elif self.policy_type == 'discrete':
            np.save(os.path.join(path, "logits_hist"), np.array(self.logits_hist, dtype=object))
            np.save(os.path.join(path, "entropy_hist"), np.array(self.entropy_hist))
Exemple #16
0
def main() -> None:
    """Entrypoint.

    Loads the configuration module named by ``args.config`` and builds the
    train/valid/test data, the VAE model, its optimizer and an exponential
    LR scheduler. With ``args.mode == "predict"`` it samples sentences from
    the checkpoint ``args.model`` and writes them to ``args.out`` (or to
    stdout); otherwise it trains, checkpointing on every validation
    improvement and decaying the learning rate after ``threshold`` epochs
    without improvement, up to ``max_decay`` times.
    """
    config: Any = importlib.import_module(args.config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_data = tx.data.MonoTextData(config.train_data_hparams, device=device)
    val_data = tx.data.MonoTextData(config.val_data_hparams, device=device)
    test_data = tx.data.MonoTextData(config.test_data_hparams, device=device)

    iterator = tx.data.DataIterator({
        "train": train_data,
        "valid": val_data,
        "test": test_data
    })

    opt_vars = {
        'learning_rate': config.lr_decay_hparams["init_lr"],
        'best_valid_nll': 1e100,
        'steps_not_improved': 0,
        'kl_weight': config.kl_anneal_hparams["start"]
    }

    decay_cnt = 0
    max_decay = config.lr_decay_hparams["max_decay"]
    decay_factor = config.lr_decay_hparams["decay_factor"]
    decay_ts = config.lr_decay_hparams["threshold"]

    save_dir = f"./models/{config.dataset}"

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)

    suffix = f"{config.dataset}_{config.decoder_type}Decoder.ckpt"

    save_path = os.path.join(save_dir, suffix)

    # KL term annealing rate: kl_weight grows by this amount at every use
    # until it is capped at 1.0.
    anneal_r = 1.0 / (config.kl_anneal_hparams["warm_up"] *
                      (len(train_data) / config.batch_size))

    vocab = train_data.vocab
    model = VAE(train_data.vocab.size, config)
    model.to(device)

    start_tokens = torch.full((config.batch_size, ),
                              vocab.bos_token_id,
                              dtype=torch.long).to(device)
    end_token = vocab.eos_token_id
    optimizer = tx.core.get_optimizer(params=model.parameters(),
                                      hparams=config.opt_hparams)
    scheduler = ExponentialLR(optimizer, decay_factor)

    def _run_epoch(epoch: int, mode: str, display: int = 10) \
            -> Tuple[Tensor, float]:
        """Run one pass over the ``mode`` split; return ``(nll, ppl)``.

        In ``'train'`` mode the model is optimised and progress is printed
        every ``display`` steps; in other modes the model is only evaluated
        with a KL weight of 1.0.
        """
        iterator.switch_to_dataset(mode)

        training = mode == 'train'
        if training:
            model.train()
            opt_vars["kl_weight"] = min(1.0, opt_vars["kl_weight"] + anneal_r)

            kl_weight = opt_vars["kl_weight"]
        else:
            model.eval()
            kl_weight = 1.0
        step = 0
        start_time = time.time()
        num_words = 0
        nll_total = 0.

        avg_rec = tx.utils.AverageRecorder()
        for batch in iterator:
            # Only training needs gradients; evaluation passes skip building
            # the autograd graph.
            with torch.set_grad_enabled(training):
                ret = model(batch, kl_weight, start_tokens, end_token)
            if training:
                opt_vars["kl_weight"] = min(1.0,
                                            opt_vars["kl_weight"] + anneal_r)
                kl_weight = opt_vars["kl_weight"]
                ret["nll"].backward()
                optimizer.step()
                optimizer.zero_grad()

            batch_size = len(ret["lengths"])
            num_words += torch.sum(ret["lengths"]).item()
            nll_total += ret["nll"].item() * batch_size
            avg_rec.add([
                ret["nll"].item(), ret["kl_loss"].item(),
                ret["rc_loss"].item()
            ], batch_size)
            if step % display == 0 and training:
                nll = avg_rec.avg(0)
                klw = opt_vars["kl_weight"]
                KL = avg_rec.avg(1)
                rc = avg_rec.avg(2)
                log_ppl = nll_total / num_words
                ppl = math.exp(log_ppl)
                time_cost = time.time() - start_time

                print(
                    f"{mode}: epoch {epoch}, step {step}, nll {nll:.4f}, "
                    f"klw {klw:.4f}, KL {KL:.4f}, rc {rc:.4f}, "
                    f"log_ppl {log_ppl:.4f}, ppl {ppl:.4f}, "
                    f"time_cost {time_cost:.1f}",
                    flush=True)

            step += 1

        nll = avg_rec.avg(0)
        KL = avg_rec.avg(1)
        rc = avg_rec.avg(2)
        log_ppl = nll_total / num_words
        ppl = math.exp(log_ppl)
        print(f"\n{mode}: epoch {epoch}, nll {nll:.4f}, KL {KL:.4f}, "
              f"rc {rc:.4f}, log_ppl {log_ppl:.4f}, ppl {ppl:.4f}")
        return nll, ppl  # type: ignore

    @torch.no_grad()
    def _generate(start_tokens: torch.LongTensor,
                  end_token: int,
                  filename: Optional[str] = None):
        """Sample sentences from the checkpoint ``args.model``.

        One sentence per line is written to ``filename``, or to stdout when
        ``filename`` is ``None``.
        """
        # map_location lets a checkpoint saved on another device (e.g. GPU)
        # be loaded on the device used by this run.
        ckpt = torch.load(args.model, map_location=device)
        model.load_state_dict(ckpt['model'])
        model.eval()

        batch_size = train_data.batch_size

        # Latent codes are drawn from a standard normal distribution.
        dst = MultivariateNormalDiag(loc=torch.zeros(batch_size,
                                                     config.latent_dims),
                                     scale_diag=torch.ones(
                                         batch_size, config.latent_dims))

        latent_z = dst.rsample().to(device)

        helper = model.decoder.create_helper(decoding_strategy='infer_sample',
                                             start_tokens=start_tokens,
                                             end_token=end_token)
        outputs = model.decode(helper=helper,
                               latent_z=latent_z,
                               max_decoding_length=100)

        sample_tokens = vocab.map_ids_to_tokens_py(outputs.sample_id.cpu())

        def _write_samples(fh) -> None:
            """Write each sample up to and including its first EOS token."""
            for sent in sample_tokens:
                sent = tx.utils.compat_as_text(list(sent))
                end_id = len(sent)
                if vocab.eos_token in sent:
                    end_id = sent.index(vocab.eos_token)
                fh.write(' '.join(sent[:end_id + 1]) + '\n')

        if filename is None:
            # Never close sys.stdout: later prints would fail.
            _write_samples(sys.stdout)
        else:
            with open(filename, 'w', encoding='utf-8') as fh:
                _write_samples(fh)

        print('Output done')

    if args.mode == "predict":
        _generate(start_tokens, end_token, args.out)
        return
    # Counts trainable parameters
    total_parameters = sum(param.numel() for param in model.parameters())
    print(f"{total_parameters} total parameters")

    best_nll = best_ppl = 0.

    for epoch in range(config.num_epochs):
        _, _ = _run_epoch(epoch, 'train', display=200)
        val_nll, _ = _run_epoch(epoch, 'valid')
        test_nll, test_ppl = _run_epoch(epoch, 'test')

        if val_nll < opt_vars['best_valid_nll']:
            # New best validation NLL: remember the matching test metrics and
            # checkpoint model, optimizer and scheduler together.
            opt_vars['best_valid_nll'] = val_nll
            opt_vars['steps_not_improved'] = 0
            best_nll = test_nll
            best_ppl = test_ppl

            states = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict()
            }
            torch.save(states, save_path)
        else:
            opt_vars['steps_not_improved'] += 1
            if opt_vars['steps_not_improved'] == decay_ts:
                # Roll back to the best checkpoint and decay the LR once.
                # NOTE(review): the restored optimizer/scheduler state is the
                # one saved with the checkpoint, so consecutive decays
                # without a new checkpoint may reapply the same step while
                # opt_vars['learning_rate'] keeps shrinking -- confirm.
                old_lr = opt_vars['learning_rate']
                opt_vars['learning_rate'] *= decay_factor
                opt_vars['steps_not_improved'] = 0
                new_lr = opt_vars['learning_rate']
                ckpt = torch.load(save_path)
                model.load_state_dict(ckpt['model'])
                optimizer.load_state_dict(ckpt['optimizer'])
                scheduler.load_state_dict(ckpt['scheduler'])
                scheduler.step()
                print(f"-----\nchange lr, old lr: {old_lr}, "
                      f"new lr: {new_lr}\n-----")

                decay_cnt += 1
                if decay_cnt == max_decay:
                    break

    print(f"\nbest testing nll: {best_nll:.4f},"
          f"best testing ppl {best_ppl:.4f}\n")
Exemple #17
0
                    loss_kld=loss_kld_mean.measure,
                    epoch=i + 1)

        if slurm:
            progress.update(
                progress.value,
                loss_nle=loss_nle_mean.measure,
                loss_encoder=loss_encoder_mean.measure,
                loss_decoder=loss_decoder_mean.measure,
                loss_discriminator=loss_discriminator_mean.measure,
                loss_mse_layer=loss_reconstruction_layer_mean.measure,
                loss_kld=loss_kld_mean.measure,
                epoch=i + 1)

        # EPOCH END
        lr_encoder.step()
        lr_decoder.step()
        lr_discriminator.step()
        margin *= decay_margin
        equilibrium *= decay_equilibrium
        torch.save({
            'epoch': step_index,
            "net": net.module.state_dict()
        }, save_path + f'model_epoch_{step_index}.tar')
        #margin non puo essere piu alto di equilibrium
        if margin > equilibrium:
            equilibrium = margin
        lambda_mse *= decay_mse
        if lambda_mse > 1:
            lambda_mse = 1
        progress.finish()
class SamplingMultitaskTrainer:
    def __init__(self,
                 dataset=None,
                 model_name=None,
                 model_params=None,
                 trainer_params=None,
                 restore=None,
                 device=None,
                 pretrained_embeddings_path=None,
                 tokenizer_path=None):
        """Build the graph model, embedders, link predictors, optimizer and loaders.

        :param dataset: object exposing ``g``, ``nodes`` and the
            ``load_node_names`` / ``load_var_use`` / ``load_api_call`` methods
            used below
        :param model_name: callable invoked as
            ``model_name(dataset.g, **model_params)`` to build the graph model
            (despite the name it is called, not looked up)
        :param model_params: keyword arguments for the graph model; its
            ``"h_dim"`` entry is also used as the node embedding size
        :param trainer_params: dict of training settings; read here for
            ``"embedding_table_size"`` and by the properties of this class
        :param restore: when truthy, module weights are restored from
            ``model_base_path`` before the optimizer is created
        :param device: device the graph model, subword embedders and link
            predictors are moved to
        :param pretrained_embeddings_path: optional path forwarded to
            ``create_node_embedder``
        :param tokenizer_path: tokenizer path forwarded to the subword element
            embedders and to ``create_node_embedder``
        """

        self.graph_model = model_name(dataset.g, **model_params).to(device)
        self.model_params = model_params
        self.trainer_params = trainer_params
        self.device = device
        # Progress counters; overwritten by restore_from_checkpoint when restoring.
        self.epoch = 0
        self.batch = 0
        self.dtype = torch.float32
        self.create_node_embedder(
            dataset,
            tokenizer_path,
            n_dims=model_params["h_dim"],
            pretrained_path=pretrained_embeddings_path,
            n_buckets=trainer_params["embedding_table_size"])

        self.summary_writer = SummaryWriter(self.model_base_path)

        # Element embedders, one per training objective.
        self.ee_node_name = ElementEmbedderWithBpeSubwords(
            elements=dataset.load_node_names(),
            nodes=dataset.nodes,
            emb_size=self.elem_emb_size,
            tokenizer_path=tokenizer_path).to(self.device)

        self.ee_var_use = ElementEmbedderWithBpeSubwords(
            elements=dataset.load_var_use(),
            nodes=dataset.nodes,
            emb_size=self.elem_emb_size,
            tokenizer_path=tokenizer_path).to(self.device)

        # The API-call objective predicts other graph nodes, so this embedder is
        # only used for index lookup and negative sampling (see _logits_nodes);
        # it is not moved to a device and is not part of the optimizer.
        self.ee_api_call = ElementEmbedderBase(
            elements=dataset.load_api_call(),
            nodes=dataset.nodes,
            compact_dst=False,
            dst_to_global=True)

        # Each link predictor consumes a concatenated pair of embeddings, hence
        # the summed input sizes.
        self.lp_node_name = LinkPredictor(self.ee_node_name.emb_size +
                                          self.graph_model.emb_size).to(
                                              self.device)
        self.lp_var_use = LinkPredictor(self.ee_var_use.emb_size +
                                        self.graph_model.emb_size).to(
                                            self.device)
        self.lp_api_call = LinkPredictor(self.graph_model.emb_size +
                                         self.graph_model.emb_size).to(
                                             self.device)

        if restore:
            self.restore_from_checkpoint(self.model_base_path)

        # Created after the optional restore; optimizer state is not part of the
        # checkpoint, so it always starts fresh.
        self.optimizer = self._create_optimizer()

        # gamma=1.0 multiplies the learning rate by 1 on every step, i.e. the
        # schedule currently leaves the learning rate unchanged.
        self.lr_scheduler = ExponentialLR(self.optimizer, gamma=1.0)
        self.best_score = BestScoreTracker()

        self._create_loaders(*self._get_training_targets())

    def create_node_embedder(self,
                             dataset,
                             tokenizer_path,
                             n_dims=None,
                             pretrained_path=None,
                             n_buckets=500000):
        """Create ``self.node_embedder``, optionally seeded from pretrained vectors.

        :param dataset: object providing ``nodes`` and
            ``buckets_from_pretrained_embeddings``
        :param tokenizer_path: accepted for interface compatibility; not used
            by this method
        :param n_dims: embedding size; when omitted it is taken from the
            pretrained embeddings
        :param pretrained_path: optional path to pretrained embeddings
        :param n_buckets: number of buckets passed to ``NodeEmbedder``
        :raises ValueError: if neither ``n_dims`` nor ``pretrained_path`` is given
        :raises AssertionError: if both are given and the sizes disagree
        """
        from SourceCodeTools.nlp.embed.fasttext import load_w2v_map

        use_pretrained = pretrained_path is not None
        pretrained = load_w2v_map(pretrained_path) if use_pretrained else None

        if use_pretrained:
            if n_dims is None:
                # Size not requested explicitly: inherit it from the vectors.
                n_dims = pretrained.n_dims
            else:
                assert n_dims == pretrained.n_dims, f"Requested embedding size and pretrained embedding " \
                                                    f"size should match: {n_dims} != {pretrained.n_dims}"
        elif n_dims is None:
            raise ValueError(
                "Specify embedding dimensionality or provide pretrained embeddings"
            )

        if pretrained is not None:
            logging.info("Loading pretrained embeddings...")
        logging.info(f"Input embedding size is {n_dims}")

        if use_pretrained:
            bucket_init = dataset.buckets_from_pretrained_embeddings(
                pretrained_path, n_buckets)
        else:
            bucket_init = None

        self.node_embedder = NodeEmbedder(nodes=dataset.nodes,
                                          emb_size=n_dims,
                                          dtype=self.dtype,
                                          pretrained=bucket_init,
                                          n_buckets=n_buckets)

    @property
    def lr(self):
        """Learning rate, read from ``trainer_params['lr']``."""
        return self.trainer_params['lr']

    @property
    def batch_size(self):
        """Batch size for the data loaders, read from ``trainer_params['batch_size']``."""
        return self.trainer_params['batch_size']

    @property
    def sampling_neighbourhood_size(self):
        """Per-layer fanout given to the neighbour samplers
        (``trainer_params['sampling_neighbourhood_size']``)."""
        return self.trainer_params['sampling_neighbourhood_size']

    @property
    def neg_sampling_factor(self):
        """Number of negative examples drawn per positive example
        (``trainer_params['neg_sampling_factor']``)."""
        return self.trainer_params['neg_sampling_factor']

    @property
    def epochs(self):
        """Upper bound of the epoch range in ``train_all``
        (``trainer_params['epochs']``)."""
        return self.trainer_params['epochs']

    @property
    def elem_emb_size(self):
        """Embedding size passed to the subword element embedders
        (``trainer_params['elem_emb_size']``)."""
        return self.trainer_params['elem_emb_size']

    @property
    def node_name_file(self):
        """Value of ``trainer_params['node_name_file']`` (not used in the visible
        part of this class)."""
        return self.trainer_params['node_name_file']

    @property
    def var_use_file(self):
        """Value of ``trainer_params['var_use_file']`` (not used in the visible
        part of this class)."""
        return self.trainer_params['var_use_file']

    @property
    def call_seq_file(self):
        """Value of ``trainer_params['call_seq_file']`` (not used in the visible
        part of this class)."""
        return self.trainer_params['call_seq_file']

    @property
    def model_base_path(self):
        """Directory used for the summary writer and for checkpoints
        (``trainer_params['model_base_path']``)."""
        return self.trainer_params['model_base_path']

    @property
    def pretraining(self):
        """Whether the current epoch has reached ``trainer_params['pretraining_phase']``.

        The value is passed as ``train_embeddings`` to the node embedder.
        NOTE(review): despite the name this is True once the pretraining phase is
        *over* (``epoch >= pretraining_phase``) - confirm that this is intended.
        """
        return self.epoch >= self.trainer_params['pretraining_phase']

    @property
    def do_save(self):
        """Whether ``train_all`` saves a checkpoint after every epoch
        (``trainer_params['save_checkpoints']``)."""
        return self.trainer_params['save_checkpoints']

    # def _extract_embed(self, node_embed, input_nodes):
    #     emb = {}
    #     for node_type, nid in input_nodes.items():
    #         emb[node_type] = node_embed[node_type][nid]
    #     return emb

    def write_summary(self, scores, batch_step):
        # main_name = os.path.basename(self.model_base_path)
        for var, val in scores.items():
            # self.summary_writer.add_scalar(f"{main_name}/{var}", val, batch_step)
            self.summary_writer.add_scalar(var, val, batch_step)
        # self.summary_writer.add_scalars(main_name, scores, batch_step)

    def write_hyperparams(self, scores, epoch):
        params = copy(self.model_params)
        params["epoch"] = epoch
        main_name = os.path.basename(self.model_base_path)
        params = {
            k: v
            for k, v in params.items()
            if type(v) in {int, float, str, bool, torch.Tensor}
        }

        main_name = os.path.basename(self.model_base_path)
        scores = {f"h_metric/{k}": v for k, v in scores.items()}
        self.summary_writer.add_hparams(params,
                                        scores,
                                        run_name=f"h_metric/{epoch}")

    def _extract_embed(self, input_nodes):
        emb = {}
        for node_type, nid in input_nodes.items():
            emb[node_type] = self.node_embedder(
                node_type=node_type,
                node_ids=nid,
                train_embeddings=self.pretraining).to(self.device)
        return emb

    def _logits_batch(self, input_nodes, blocks):

        cumm_logits = []

        if self.use_types:
            # emb = self._extract_embed(self.graph_model.node_embed(), input_nodes)
            emb = self._extract_embed(input_nodes)
        else:
            if self.ntypes is not None:
                # single node type
                key = next(iter(self.ntypes))
                input_nodes = {key: input_nodes}
                # emb = self._extract_embed(self.graph_model.node_embed(), input_nodes)
                emb = self._extract_embed(input_nodes)
            else:
                emb = self.node_embedder(node_ids=input_nodes,
                                         train_embeddings=self.pretraining)
                # emb = self.graph_model.node_embed()[input_nodes]

        logits = self.graph_model(emb, blocks)

        if self.use_types:
            for ntype in self.graph_model.g.ntypes:

                logits_ = logits.get(ntype, None)
                if logits_ is None:
                    continue

                cumm_logits.append(logits_)
        else:
            if self.ntypes is not None:
                # single node type
                key = next(iter(self.ntypes))
                logits_ = logits[key]
            else:
                logits_ = logits

            cumm_logits.append(logits_)

        return torch.cat(cumm_logits)

    def seeds_to_global(self, seeds):
        if type(seeds) is dict:
            indices = [
                self.graph_model.g.nodes[ntype].data["global_graph_id"][
                    seeds[ntype]] for ntype in seeds
            ]
            return torch.cat(indices, dim=0)
        else:
            return seeds

    def _logits_embedder(self,
                         node_embeddings,
                         elem_embedder,
                         link_predictor,
                         seeds,
                         negative_factor=1):
        k = negative_factor
        indices = self.seeds_to_global(seeds)
        batch_size = len(indices)

        node_embeddings_batch = node_embeddings
        element_embeddings = elem_embedder(elem_embedder[indices.tolist()].to(
            self.device))

        positive_batch = torch.cat([node_embeddings_batch, element_embeddings],
                                   1)
        labels_pos = torch.ones(batch_size, dtype=torch.long)

        node_embeddings_neg_batch = node_embeddings_batch.repeat(k, 1)
        negative_random = elem_embedder(
            elem_embedder.sample_negative(batch_size * k).to(self.device))

        negative_batch = torch.cat(
            [node_embeddings_neg_batch, negative_random], 1)
        labels_neg = torch.zeros(batch_size * k, dtype=torch.long)

        batch = torch.cat([positive_batch, negative_batch], 0)
        labels = torch.cat([labels_pos, labels_neg], 0).to(self.device)

        logits = link_predictor(batch)

        return logits, labels

    def _handle_non_unique(self, non_unique_ids):
        id_list = non_unique_ids.tolist()
        unique_ids = list(set(id_list))
        new_position = dict(zip(unique_ids, range(len(unique_ids))))
        slice_map = torch.tensor(list(map(lambda x: new_position[x], id_list)),
                                 dtype=torch.long)
        return torch.tensor(unique_ids, dtype=torch.long), slice_map

    def _logits_nodes(self,
                      node_embeddings,
                      elem_embedder,
                      link_predictor,
                      create_dataloader,
                      src_seeds,
                      negative_factor=1):
        """Score (node, node) pairs for an objective whose targets are graph nodes.

        Unlike ``_logits_embedder`` the targets are embedded by running the
        graph model on them, so a dataloader is built for the positive targets
        and another one for the sampled negatives.

        :param node_embeddings: graph model outputs for the source seed nodes
        :param elem_embedder: supports ``embedder[ids]`` lookup of target node
            ids and ``sample_negative(n)``
        :param link_predictor: module applied to the concatenated pairs
        :param create_dataloader: callable that builds a dataloader yielding a
            single ``(input_nodes, seeds, blocks)`` batch for the given node ids
        :param src_seeds: source seed node ids (plain or per node type)
        :param negative_factor: number of negatives per positive
        :return: ``(logits, labels)`` with label 1 for the positive rows
            (first) and 0 for the negative rows
        :raises AssertionError: if a dataloader returns seeds in a different
            order than the unique ids it was given
        """
        k = negative_factor
        indices = self.seeds_to_global(src_seeds)
        batch_size = len(indices)

        node_embeddings_batch = node_embeddings
        next_call_indices = elem_embedder[
            indices.tolist()]  # this assumes indices is torch tensor

        # dst targets are not unique
        unique_dst, slice_map = self._handle_non_unique(next_call_indices)
        assert unique_dst[slice_map].tolist() == next_call_indices.tolist()

        # Embed every distinct target once; only the first batch of the loader
        # is consumed, so it must cover all requested ids.
        dataloader = create_dataloader(unique_dst)
        input_nodes, dst_seeds, blocks = next(iter(dataloader))
        blocks = [blk.to(self.device) for blk in blocks]
        # slice_map indexes by position, so the loader must keep the id order.
        assert dst_seeds.shape == unique_dst.shape
        assert dst_seeds.tolist() == unique_dst.tolist()
        unique_dst_embeddings = self._logits_batch(
            input_nodes, blocks)  # use_types, ntypes)
        # Expand the unique embeddings back to one row per source node.
        next_call_embeddings = unique_dst_embeddings[slice_map.to(self.device)]
        positive_batch = torch.cat(
            [node_embeddings_batch, next_call_embeddings], 1)
        labels_pos = torch.ones(batch_size, dtype=torch.long)

        node_embeddings_neg_batch = node_embeddings_batch.repeat(k, 1)
        negative_indices = torch.tensor(
            elem_embedder.sample_negative(batch_size * k), dtype=torch.long
        )  # embeddings are sampled from 3/4 unigram distribution
        # slice_map is rebound here; the positive mapping is no longer needed.
        unique_negative, slice_map = self._handle_non_unique(negative_indices)
        assert unique_negative[slice_map].tolist() == negative_indices.tolist()

        # Same procedure as above, now for the sampled negative targets.
        dataloader = create_dataloader(unique_negative)
        input_nodes, dst_seeds, blocks = next(iter(dataloader))
        blocks = [blk.to(self.device) for blk in blocks]
        assert dst_seeds.shape == unique_negative.shape
        assert dst_seeds.tolist() == unique_negative.tolist()
        unique_negative_random = self._logits_batch(
            input_nodes, blocks)  # use_types, ntypes)
        negative_random = unique_negative_random[slice_map.to(self.device)]
        negative_batch = torch.cat(
            [node_embeddings_neg_batch, negative_random], 1)
        labels_neg = torch.zeros(batch_size * k, dtype=torch.long)

        batch = torch.cat([positive_batch, negative_batch], 0)
        labels = torch.cat([labels_pos, labels_neg], 0).to(self.device)

        logits = link_predictor(batch)

        return logits, labels

    def _logits_node_name(self, input_nodes, seeds, blocks):
        src_embs = self._logits_batch(input_nodes, blocks)
        logits, labels = self._logits_embedder(
            src_embs,
            self.ee_node_name,
            self.lp_node_name,
            seeds,
            negative_factor=self.neg_sampling_factor)
        return logits, labels

    def _logits_var_use(self, input_nodes, seeds, blocks):
        src_embs = self._logits_batch(input_nodes, blocks)
        logits, labels = self._logits_embedder(
            src_embs,
            self.ee_var_use,
            self.lp_var_use,
            seeds,
            negative_factor=self.neg_sampling_factor)
        return logits, labels

    def _logits_api_call(self, input_nodes, seeds, blocks):
        src_embs = self._logits_batch(input_nodes, blocks)
        logits, labels = self._logits_nodes(
            src_embs,
            self.ee_api_call,
            self.lp_api_call,
            self._create_api_call_loader,
            seeds,
            negative_factor=self.neg_sampling_factor)
        return logits, labels

    def _get_training_targets(self):
        if hasattr(self.graph_model.g, 'ntypes'):
            self.ntypes = self.graph_model.g.ntypes
            # labels = {ntype: self.graph_model.g.nodes[ntype].data['labels'] for ntype in self.ntypes}
            self.use_types = True

            if len(self.graph_model.g.ntypes) == 1:
                # key = next(iter(labels.keys()))
                # labels = labels[key]
                self.use_types = False

            train_idx = {
                ntype: torch.nonzero(
                    self.graph_model.g.nodes[ntype].data['train_mask'],
                    as_tuple=False).squeeze()
                for ntype in self.ntypes
            }
            val_idx = {
                ntype:
                torch.nonzero(self.graph_model.g.nodes[ntype].data['val_mask'],
                              as_tuple=False).squeeze()
                for ntype in self.ntypes
            }
            test_idx = {
                ntype: torch.nonzero(
                    self.graph_model.g.nodes[ntype].data['test_mask'],
                    as_tuple=False).squeeze()
                for ntype in self.ntypes
            }
        else:
            self.ntypes = None
            # labels = g.ndata['labels']
            train_idx = self.graph_model.g.ndata['train_mask']
            val_idx = self.graph_model.g.ndata['val_mask']
            test_idx = self.graph_model.g.ndata['test_mask']
            self.use_types = False

        return train_idx, val_idx, test_idx

    def _evaluate_embedder(self, ee, lp, loader, neg_sampling_factor=1):
        """Evaluate an element-embedder objective over every batch of ``loader``.

        :param ee: element embedder of the objective
        :param lp: link predictor of the objective
        :param loader: iterable of ``(input_nodes, seeds, blocks)`` batches
        :param neg_sampling_factor: number of negatives per positive
        :return: ``(mean_loss, mean_accuracy)``, both averaged over batches
        :raises ZeroDivisionError: if ``loader`` yields no batches
        """

        total_loss = 0
        total_acc = 0
        count = 0

        for input_nodes, seeds, blocks in loader:
            blocks = [blk.to(self.device) for blk in blocks]

            src_embs = self._logits_batch(input_nodes, blocks)
            logits, labels = self._logits_embedder(src_embs, ee, lp, seeds,
                                                   neg_sampling_factor)

            logp = nn.functional.log_softmax(logits, 1)
            # logp already holds log-probabilities, so use nll_loss (as in
            # train_all); cross_entropy would apply log_softmax a second time.
            loss = nn.functional.nll_loss(logp, labels)
            acc = compute_accuracy(logp.argmax(dim=1), labels)

            total_loss += loss.item()
            total_acc += acc
            count += 1
        return total_loss / count, total_acc / count

    def _evaluate_nodes(self,
                        ee,
                        lp,
                        create_api_call_loader,
                        loader,
                        neg_sampling_factor=1):
        """Evaluate a node-target objective over every batch of ``loader``.

        :param ee: element embedder used for target lookup and negative sampling
        :param lp: link predictor of the objective
        :param create_api_call_loader: callable building a dataloader for a
            set of target node ids (forwarded to ``_logits_nodes``)
        :param loader: iterable of ``(input_nodes, seeds, blocks)`` batches
        :param neg_sampling_factor: number of negatives per positive
        :return: ``(mean_loss, mean_accuracy)``, both averaged over batches
        :raises ZeroDivisionError: if ``loader`` yields no batches
        """

        total_loss = 0
        total_acc = 0
        count = 0

        for input_nodes, seeds, blocks in loader:
            blocks = [blk.to(self.device) for blk in blocks]

            src_embs = self._logits_batch(input_nodes, blocks)
            logits, labels = self._logits_nodes(src_embs, ee, lp,
                                                create_api_call_loader, seeds,
                                                neg_sampling_factor)

            logp = nn.functional.log_softmax(logits, 1)
            # logp already holds log-probabilities, so use nll_loss (as in
            # train_all); cross_entropy would apply log_softmax a second time.
            loss = nn.functional.nll_loss(logp, labels)
            acc = compute_accuracy(logp.argmax(dim=1), labels)

            total_loss += loss.item()
            total_acc += acc
            count += 1
        return total_loss / count, total_acc / count

    def _evaluate_objectives(self, loader_node_name, loader_var_use,
                             loader_api_call, neg_sampling_factor):

        node_name_loss, node_name_acc = self._evaluate_embedder(
            self.ee_node_name,
            self.lp_node_name,
            loader_node_name,
            neg_sampling_factor=neg_sampling_factor)

        var_use_loss, var_use_acc = self._evaluate_embedder(
            self.ee_var_use,
            self.lp_var_use,
            loader_var_use,
            neg_sampling_factor=neg_sampling_factor)

        api_call_loss, api_call_acc = self._evaluate_nodes(
            self.ee_api_call,
            self.lp_api_call,
            self._create_api_call_loader,
            loader_api_call,
            neg_sampling_factor=neg_sampling_factor)

        loss = node_name_loss + var_use_loss + api_call_loss

        return loss, node_name_acc, var_use_acc, api_call_acc

    def _idx_len(self, idx):
        if isinstance(idx, dict):
            length = 0
            for key in idx:
                length += len(idx[key])
        else:
            length = len(idx)
        return length

    def _get_loaders(self, train_idx, val_idx, test_idx, batch_size):
        """Build neighbour-sampling data loaders for the three splits.

        All three loaders are configured identically: a fresh
        ``MultiLayerNeighborSampler`` with ``sampling_neighbourhood_size``
        neighbours per layer (full-neighbour sampling is avoided to save
        computation), ``shuffle=False`` and ``num_workers=0``.
        NOTE(review): the training loader is not shuffled either - confirm
        this is intended.

        :param train_idx: node ids of the training split
        :param val_idx: node ids of the validation split
        :param test_idx: node ids of the test split
        :param batch_size: batch size for every loader
        :return: ``(train_loader, val_loader, test_loader)``
        """
        n_layers = self.graph_model.num_layers

        def build_loader(node_ids):
            # Each loader gets its own sampler instance.
            fanouts = [self.sampling_neighbourhood_size] * n_layers
            sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
            return dgl.dataloading.NodeDataLoader(self.graph_model.g,
                                                  node_ids,
                                                  sampler,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=0)

        train_loader = build_loader(train_idx)
        val_loader = build_loader(val_idx)
        test_loader = build_loader(test_idx)

        return train_loader, val_loader, test_loader

    def _create_loaders(self, train_idx, val_idx, test_idx):

        train_idx_node_name, val_idx_node_name, test_idx_node_name = self.ee_node_name.create_idx_pools(
            train_idx=train_idx, val_idx=val_idx, test_idx=test_idx)
        train_idx_var_use, val_idx_var_use, test_idx_var_use = self.ee_var_use.create_idx_pools(
            train_idx=train_idx, val_idx=val_idx, test_idx=test_idx)
        train_idx_api_call, val_idx_api_call, test_idx_api_call = self.ee_api_call.create_idx_pools(
            train_idx=train_idx, val_idx=val_idx, test_idx=test_idx)

        logging.info(
            f"Pool sizes : train {self._idx_len(train_idx_node_name)}, "
            f"val {self._idx_len(val_idx_node_name)}, "
            f"test {self._idx_len(test_idx_node_name)}.")
        logging.info(f"Pool sizes : train {self._idx_len(train_idx_var_use)}, "
                     f"val {self._idx_len(val_idx_var_use)}, "
                     f"test {self._idx_len(test_idx_var_use)}.")
        logging.info(
            f"Pool sizes : train {self._idx_len(train_idx_api_call)}, "
            f"val {self._idx_len(val_idx_api_call)}, "
            f"test {self._idx_len(test_idx_api_call)}.")

        self.loader_node_name, self.val_loader_node_name, self.test_loader_node_name = self._get_loaders(
            train_idx=train_idx_node_name,
            val_idx=val_idx_node_name,
            test_idx=test_idx_node_name,
            batch_size=self.batch_size  # batch_size_node_name
        )
        self.loader_var_use, self.val_loader_var_use, self.test_loader_var_use = self._get_loaders(
            train_idx=train_idx_var_use,
            val_idx=val_idx_var_use,
            test_idx=test_idx_var_use,
            batch_size=self.batch_size  # batch_size_var_use
        )
        self.loader_api_call, self.val_loader_api_call, self.test_loader_api_call = self._get_loaders(
            train_idx=train_idx_api_call,
            val_idx=val_idx_api_call,
            test_idx=test_idx_api_call,
            batch_size=self.batch_size  # batch_size_api_call
        )

    def _create_api_call_loader(self, indices):
        """Build a loader that returns all of ``indices`` in a single batch.

        :param indices: node ids to embed
        :return: ``NodeDataLoader`` whose batch size equals ``len(indices)``
        """
        fanouts = [self.sampling_neighbourhood_size
                   ] * self.graph_model.num_layers
        neighbour_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
        return dgl.dataloading.NodeDataLoader(self.graph_model.g,
                                              indices,
                                              neighbour_sampler,
                                              batch_size=len(indices),
                                              num_workers=0)

    def _create_optimizer(self):
        # optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        optimizer = torch.optim.Adam(
            [
                {
                    'params': self.graph_model.parameters()
                },
                {
                    'params': self.node_embedder.parameters()
                },
                {
                    'params': self.ee_node_name.parameters()
                },
                {
                    'params': self.ee_var_use.parameters()
                },
                # {'params': self.ee_api_call.parameters()},
                {
                    'params': self.lp_node_name.parameters()
                },
                {
                    'params': self.lp_var_use.parameters()
                },
                {
                    'params': self.lp_api_call.parameters()
                },
            ],
            lr=self.lr)
        return optimizer

    def train_all(self):
        """
        Train all three objectives jointly, from ``self.epoch`` up to ``self.epochs``.

        Every step draws one batch from each objective's loader, concatenates
        the three sets of logits and labels, and performs a single optimizer
        step on the joint negative log-likelihood. After each epoch the model is
        evaluated on the validation and test loaders, the best score tracker is
        updated, a checkpoint is written when ``do_save`` is set, summaries are
        logged and the learning-rate scheduler is stepped.

        Because the loaders are zipped, an epoch ends as soon as the shortest
        of the three is exhausted. The ``loss`` and training accuracies reported
        per epoch are those of the last training batch; if the zipped loaders
        yield no batch at all these names are never bound.
        :return: None
        """

        for epoch in range(self.epoch, self.epochs):
            self.epoch = epoch

            start = time()

            for i, ((input_nodes_node_name, seeds_node_name, blocks_node_name),
                    (input_nodes_var_use, seeds_var_use, blocks_var_use),
                    (input_nodes_api_call, seeds_api_call, blocks_api_call)) in \
                    enumerate(zip(
                        self.loader_node_name,
                        self.loader_var_use,
                        self.loader_api_call)):

                blocks_node_name = [
                    blk.to(self.device) for blk in blocks_node_name
                ]
                blocks_var_use = [
                    blk.to(self.device) for blk in blocks_var_use
                ]
                blocks_api_call = [
                    blk.to(self.device) for blk in blocks_api_call
                ]

                logits_node_name, labels_node_name = self._logits_node_name(
                    input_nodes_node_name, seeds_node_name, blocks_node_name)

                logits_var_use, labels_var_use = self._logits_var_use(
                    input_nodes_var_use, seeds_var_use, blocks_var_use)

                logits_api_call, labels_api_call = self._logits_api_call(
                    input_nodes_api_call, seeds_api_call, blocks_api_call)

                train_acc_node_name = compute_accuracy(
                    logits_node_name.argmax(dim=1), labels_node_name)
                train_acc_var_use = compute_accuracy(
                    logits_var_use.argmax(dim=1), labels_var_use)
                train_acc_api_call = compute_accuracy(
                    logits_api_call.argmax(dim=1), labels_api_call)

                # One joint loss over the examples of all three objectives.
                train_logits = torch.cat(
                    [logits_node_name, logits_var_use, logits_api_call], 0)
                train_labels = torch.cat(
                    [labels_node_name, labels_var_use, labels_api_call], 0)

                logp = nn.functional.log_softmax(train_logits, 1)
                loss = nn.functional.nll_loss(logp, train_labels)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                self.write_summary(
                    {
                        "Loss": loss,
                        "Accuracy/train/node_name_vs_batch":
                        train_acc_node_name,
                        "Accuracy/train/var_use_vs_batch": train_acc_var_use,
                        "Accuracy/train/api_call_vs_batch": train_acc_api_call
                    }, self.batch)
                # Global batch counter, also stored in checkpoints.
                self.batch += 1

            # Switch the modules to evaluation mode for the epoch-end metrics.
            self.eval()

            with torch.set_grad_enabled(False):

                _, val_acc_node_name, val_acc_var_use, val_acc_api_call = self._evaluate_objectives(
                    self.val_loader_node_name, self.val_loader_var_use,
                    self.val_loader_api_call, self.neg_sampling_factor)

                _, test_acc_node_name, test_acc_var_use, test_acc_api_call = self._evaluate_objectives(
                    self.test_loader_node_name, self.test_loader_var_use,
                    self.test_loader_api_call, self.neg_sampling_factor)

            self.train()

            end = time()

            # loss and train_acc_* below come from the last training batch.
            self.best_score.track_best(epoch=epoch,
                                       loss=loss.item(),
                                       train_acc_node_name=train_acc_node_name,
                                       val_acc_node_name=val_acc_node_name,
                                       test_acc_node_name=test_acc_node_name,
                                       train_acc_var_use=train_acc_var_use,
                                       val_acc_var_use=val_acc_var_use,
                                       test_acc_var_use=test_acc_var_use,
                                       train_acc_api_call=train_acc_api_call,
                                       val_acc_api_call=val_acc_api_call,
                                       test_acc_api_call=test_acc_api_call,
                                       time=end - start)

            if self.do_save:
                self.save_checkpoint(self.model_base_path)

            # Epoch-end accuracies are logged against the global batch counter.
            self.write_summary(
                {
                    "Accuracy/test/node_name_vs_batch": test_acc_node_name,
                    "Accuracy/test/var_use_vs_batch": test_acc_var_use,
                    "Accuracy/test/api_call_vs_batch": test_acc_api_call,
                    "Accuracy/val/node_name_vs_batch": val_acc_node_name,
                    "Accuracy/val/var_use_vs_batch": val_acc_var_use,
                    "Accuracy/val/api_call_vs_batch": val_acc_api_call
                }, self.batch)

            self.write_hyperparams(
                {
                    "Loss/train_vs_epoch": loss,
                    "Accuracy/train/node_name_vs_epoch": train_acc_node_name,
                    "Accuracy/train/var_use_vs_epoch": train_acc_var_use,
                    "Accuracy/train/api_call_vs_epoch": train_acc_api_call,
                    "Accuracy/test/node_name_vs_epoch": test_acc_node_name,
                    "Accuracy/test/var_use_vs_epoch": test_acc_var_use,
                    "Accuracy/test/api_call_vs_epoch": test_acc_api_call,
                    "Accuracy/val/node_name_vs_epoch": val_acc_node_name,
                    "Accuracy/val/var_use_vs_epoch": val_acc_var_use,
                    "Accuracy/val/api_call_vs_epoch": val_acc_api_call
                }, self.epoch)

            self.lr_scheduler.step()

    def save_checkpoint(self,
                        checkpoint_path=None,
                        checkpoint_name=None,
                        **kwargs):

        checkpoint_path = join(checkpoint_path, "saved_state.pt")

        param_dict = {
            'graph_model': self.graph_model.state_dict(),
            'node_embedder': self.node_embedder.state_dict(),
            'ee_node_name': self.ee_node_name.state_dict(),
            'ee_var_use': self.ee_var_use.state_dict(),
            # 'ee_api_call': self.ee_api_call.state_dict(),
            "lp_node_name": self.lp_node_name.state_dict(),
            "lp_var_use": self.lp_var_use.state_dict(),
            "lp_api_call": self.lp_api_call.state_dict(),
            "epoch": self.epoch,
            "batch": self.batch
        }

        if len(kwargs) > 0:
            param_dict.update(kwargs)

        torch.save(param_dict, checkpoint_path)

    def restore_from_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(join(checkpoint_path, "saved_state.pt"))
        self.graph_model.load_state_dict(checkpoint['graph_model'])
        self.ee_node_name.load_state_dict(checkpoint['ee_node_name'])
        self.ee_var_use.load_state_dict(checkpoint['ee_var_use'])
        # self.ee_api_call.load_state_dict(checkpoint['ee_api_call'])
        self.lp_node_name.load_state_dict(checkpoint['lp_node_name'])
        self.lp_var_use.load_state_dict(checkpoint['lp_var_use'])
        self.lp_api_call.load_state_dict(checkpoint['lp_api_call'])
        self.epoch = checkpoint['epoch']
        self.batch = checkpoint['batch']
        logging.info(f"Restored from epoch {checkpoint['epoch']}")

    def final_evaluation(self):
        """Evaluate every objective on the train, validation and test loaders.

        ``_evaluate_objectives`` is called once per split with gradients
        disabled. The accuracies are printed on a single line and returned
        as a dict keyed ``"<split>_acc_<objective>"``.
        """
        objectives = ("node_name", "var_use", "api_call")
        # (key prefix in the result, loader attribute prefix, printed label)
        splits = (("train", "", "Train"),
                  ("val", "val_", "Val"),
                  ("test", "test_", "Test"))

        accuracy = {}
        with torch.set_grad_enabled(False):
            for split, attr_prefix, _ in splits:
                loaders = [
                    getattr(self, f"{attr_prefix}loader_{objective}")
                    for objective in objectives
                ]
                # The first returned value (the loss) is not reported.
                _, acc_node_name, acc_var_use, acc_api_call = \
                    self._evaluate_objectives(*loaders, 1)
                accuracy[split] = dict(
                    zip(objectives,
                        (acc_node_name, acc_var_use, acc_api_call)))

        scores = {
            f"{split}_acc_{objective}": accuracy[split][objective]
            for objective in objectives
            for split, _, _ in splits
        }

        parts = []
        for objective in objectives:
            title = objective.replace("_", " ")
            for split, _, label in splits:
                value = scores[f"{split}_acc_{objective}"]
                parts.append(f"{title} {label} Acc {value:.4f}")
        print("Final Eval : " + ", ".join(parts))

        return scores

    def eval(self):
        """Put the graph model, element embedders and link predictors in eval mode."""
        # ``ee_api_call`` is not part of this list.
        for attr in ("graph_model", "ee_node_name", "ee_var_use",
                     "lp_node_name", "lp_var_use", "lp_api_call"):
            getattr(self, attr).eval()

    def train(self):
        """Put the graph model, element embedders and link predictors in training mode."""
        # ``ee_api_call`` is not part of this list.
        for attr in ("graph_model", "ee_node_name", "ee_var_use",
                     "lp_node_name", "lp_var_use", "lp_api_call"):
            getattr(self, attr).train()

    def to(self, device):
        """Move the graph model, element embedders and link predictors to ``device``."""
        # ``ee_api_call`` is not part of this list.
        for attr in ("graph_model", "ee_node_name", "ee_var_use",
                     "lp_node_name", "lp_var_use", "lp_api_call"):
            getattr(self, attr).to(device)

    def get_embeddings(self):
        """Run full-graph inference and package the resulting node embeddings.

        Every node type is first embedded from its ``typed_id`` values (with
        ``train_embeddings=False``), then passed through
        ``graph_model.inference`` on the CPU. The per-type outputs are
        concatenated in ``ntypes`` order.

        :return: one-element list holding an ``Embedder`` built from an
            ``original_id -> global_graph_id`` mapping and the embedding matrix
        """
        graph = self.graph_model.g
        node_view = graph.nodes

        inputs = {}
        for ntype in graph.ntypes:
            inputs[ntype] = self.node_embedder(
                node_type=ntype,
                node_ids=node_view[ntype].data['typed_id'],
                train_embeddings=False)

        hidden = self.graph_model.inference(batch_size=256,
                                            device='cpu',
                                            num_workers=0,
                                            x=inputs)

        ntypes = graph.ntypes
        per_type = [hidden[ntype] for ntype in ntypes]
        original_ids = [
            node_id
            for ntype in ntypes
            for node_id in node_view[ntype].data['original_id'].tolist()
        ]
        global_ids = [
            node_id
            for ntype in ntypes
            for node_id in node_view[ntype].data['global_graph_id'].tolist()
        ]

        matrix = torch.cat(per_type, dim=0).detach().numpy()
        id_map = dict(zip(original_ids, global_ids))

        return [Embedder(id_map, matrix)]
    def train(init_dict,
              n_iterations,
              transformer,
              sim_data_node=None,
              is_hsearch=False,
              tb_writer=None,
              print_every=10):
        """Train a regression predictor with MSE loss and keep the best weights.

        Each epoch runs a training phase followed by an evaluation phase
        ('val' during hyperparameter search, 'test' otherwise). The mean
        evaluation score of each epoch is compared with the best seen so far
        and the best-scoring weights are loaded back before returning.

        :param init_dict: dict with the keys 'model', 'data_loaders',
            'optimizer' and 'metrics'
        :param n_iterations: total iteration budget; the epoch count is
            ``n_iterations // len(data_loaders['train'])``
        :param transformer: object whose ``inverse_transform`` is applied to
            targets and predictions before the metrics are computed
        :param sim_data_node: not used by this function
        :param is_hsearch: when True, evaluate on 'val' and suppress any
            exception raised inside the training loop
        :param tb_writer: optional factory; called with no arguments when
            truthy (the result is not used further here)
        :param print_every: print the epoch summary when
            ``epoch % print_every == 0``
        :return: dict with 'model', 'score' (best mean evaluation score) and
            'epoch' (index of the best epoch, -1 if none improved)
        """
        start = time.time()
        predictor = init_dict['model']
        data_loaders = init_dict['data_loaders']
        optimizer = init_dict['optimizer']
        metrics = init_dict['metrics']
        lr_sch = ExponentialLR(optimizer, gamma=0.98)
        if tb_writer:
            tb_writer = tb_writer()
        best_model_wts = predictor.state_dict()
        best_score = -10000
        best_epoch = -1
        n_epochs = n_iterations // len(data_loaders['train'])
        criterion = torch.nn.MSELoss()

        # Since during hyperparameter search values that could cause CUDA memory exception could be sampled
        # we want to ignore such values and find others that are workable within the memory constraints.
        with contextlib.suppress(Exception if is_hsearch else DummyException):
            for epoch in range(n_epochs):
                eval_scores = []
                for phase in ['train', 'val' if is_hsearch else 'test']:
                    if phase == 'train':
                        predictor.train()
                    else:
                        predictor.eval()

                    losses = []
                    metrics_dict = defaultdict(list)
                    for batch in tqdm(
                            data_loaders[phase],
                            desc=f'Phase: {phase}, epoch={epoch + 1}/{n_epochs}'
                    ):
                        batch = np.array(batch)
                        x = batch[:, 0]
                        y_true = batch[:, 1]
                        with torch.set_grad_enabled(phase == 'train'):
                            # np.float (an alias of the builtin float) was
                            # removed in NumPy 1.24; np.float64 is the same dtype.
                            y_true = torch.from_numpy(
                                y_true.reshape(-1, 1).astype(
                                    np.float64)).float().to(device)
                            y_pred = predictor(x)
                            loss = criterion(y_pred, y_true)
                            losses.append(loss.item())

                        # Perform evaluation using the given metrics
                        eval_dict = {}
                        score = ExpertTrainer.evaluate(
                            eval_dict,
                            transformer.inverse_transform(
                                y_true.cpu().detach().numpy()),
                            transformer.inverse_transform(
                                y_pred.cpu().detach().numpy()), metrics)
                        for m in eval_dict:
                            if m in metrics:
                                metrics_dict[m].append(eval_dict[m])
                            else:
                                metrics_dict[m] = [eval_dict[m]]

                        # Update weights
                        if phase == 'train':
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()
                        else:
                            eval_scores.append(score)
                    metrics_dict = {
                        k: np.mean(metrics_dict[k])
                        for k in metrics_dict
                    }
                    if epoch % print_every == 0:
                        print(
                            f'{phase}: epoch={epoch + 1}/{n_epochs}, loss={np.mean(losses)}, metrics={metrics_dict}'
                        )
                    if phase == 'train':
                        # Decay the learning rate once per training phase.
                        lr_sch.step()
                # Checkpoint
                score = np.mean(eval_scores)
                if score > best_score:
                    best_score = score
                    best_model_wts = copy.deepcopy(predictor.state_dict())
                    best_epoch = epoch
        predictor.load_state_dict(best_model_wts)
        print(f'Time elapsed: {time_since(start)}')
        return {'model': predictor, 'score': best_score, 'epoch': best_epoch}
Exemple #20
0
        output, c_D, _ = D(fake)
        lossG = bce_loss(output, y_real)
        lossinfo = mse_loss(c_D, c)
        D_G_z2 = output.data.mean()
        return lossG, lossinfo, D_G_z2

    # start training
    D.train()
    for epoch in range(FG.num_epochs):
        printers['lr']('D', epoch, optimizerD.param_groups[0]['lr'])
        printers['lr']('G',  epoch, optimizerG.param_groups[0]['lr'])
        printers['lr']('info', epoch, optimizerinfo.param_groups[0]['lr'])
        timer.tic()
        # lr scheduler
        if (epoch+1)%100 == 0:
            schedularD.step()
            schedularG.step()
            schedularinfo.step()

        G.train()
        for step, data in enumerate(trainloader):
            D.zero_grad()
            G.zero_grad()
            x = data['image'].float().cuda(device, non_blocking=True)
            y = data['target'].float().cuda(device, non_blocking=True)
            batch_size = x.size(0)

            # set variables
            z = torch.rand(batch_size, FG.z_dim).float().cuda(device, non_blocking=True)
            c = torch.from_numpy(np.random.uniform(-1, 1, size=(batch_size,\
                                   c_code))).float().cuda(device, non_blocking=True)
Exemple #21
0
def train(data_dir, net_path=None, save_dir='pretrained'):
    """Train the AlexNet-based SiamFC model on GOT-10k, saving one checkpoint per epoch.

    :param data_dir: dataset root handed to ``GOT10k``
    :param net_path: optional checkpoint to resume from; either a dict with
        'epoch', 'model' and 'optimizer' entries (the format written by this
        function) or a bare model state dict
    :param save_dir: directory that receives ``siamfc_alexnet_e<epoch>.pth``
    """
    # Load the image sequence dataset from disk.
    seq_dataset = GOT10k(data_dir, subset='train', return_meta=False)
    # Define the image preprocessing transforms.
    transforms = SiamFCTransforms(
        exemplar_sz=cfg.exemplar_sz,  # 127 per the original note
        instance_sz=cfg.instance_sz,  # 255 per the original note
        context=cfg.context)  # 0.5 per the original note
    # Pair training images from each video sequence and preprocess them
    # (cropping etc.).
    train_dataset = GOT10kDataset(seq_dataset, transforms)
    # Data loader over the training set.
    loader_dataset = DataLoader(
        dataset=train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    # Initialise the network.
    cuda = torch.cuda.is_available()  # True when a GPU is available
    device = torch.device('cuda:0' if cuda else 'cpu')  # use CUDA device 0
    model = AlexNet(init_weight=True)
    corr = _corr()
    model = model.to(device)
    corr = corr.to(device)
    # Set up the loss function and the labels.
    logist_loss = BalancedLoss()
    labels = _create_labels(
        size=[cfg.batch_size, 1, cfg.response_sz - 2, cfg.response_sz - 2])
    labels = torch.from_numpy(labels).to(device).float()
    # Build the optimizer and an exponentially decaying learning rate.
    optimizer = optim.SGD(
        model.parameters(),
        lr=cfg.initial_lr,  # initial learning rate, decayed by the scheduler
        weight_decay=cfg.weight_decay,  # L2 regularisation strength
        momentum=cfg.momentum)
    # gamma is chosen so that initial_lr * gamma ** epoch_num == ultimate_lr.
    gamma = np.power(
        cfg.ultimate_lr / cfg.initial_lr, 1.0 / cfg.epoch_num)
    lr_scheduler = ExponentialLR(optimizer, gamma)

    # ------------------------- start training -------------------------
    os.makedirs(save_dir, exist_ok=True)
    start_epoch = 1
    # Resume from an earlier run: restore the saved net, optimizer and epoch.
    if net_path is not None:
        # map_location lets a checkpoint written on a GPU be resumed on the
        # CPU fallback selected above.
        checkpoint = torch.load(net_path, map_location=device)
        if 'epoch' in checkpoint:
            start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            model.load_state_dict(checkpoint)
        del checkpoint
        torch.cuda.empty_cache()  # release cached GPU memory
        print("loaded checkpoint!!!")
    for epoch in range(start_epoch, cfg.epoch_num + 1):
        model.train()
        # Iterate over the training set.
        for it, batch in enumerate(tqdm(loader_dataset)):
            # Original notes give z as [8, 3, 127, 127] and x as
            # [8, 3, 239, 239] - TODO confirm against the transforms.
            z = batch[0].to(device, non_blocking=cuda)
            x = batch[1].to(device, non_blocking=cuda)
            # Forward pass through the network, then the loss.
            z, x = model(z), model(x)
            # Response heat map; the original note gives [8, 1, 15, 15].
            responses = corr(z, x) * cfg.out_reduce
            loss = logist_loss(responses, labels)
            # Back-propagation.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('Epoch: {}[{}/{}]    Loss: {:.5f}    lr: {:.2e}'.format(
                epoch, it + 1, len(loader_dataset), loss.item(),
                optimizer.param_groups[0]['lr']))
        # Update the learning rate once per epoch.
        lr_scheduler.step()
        # Save one checkpoint after every epoch; the directory is re-created
        # if it disappeared during training.
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, 'siamfc_alexnet_e%d.pth' % (epoch))

        torch.save(
            {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, save_path)
Exemple #22
0
class TrackerSiamFC(Tracker):
    def __init__(self, imagefile, region):
        """Build the tracker, load pretrained weights and initialise on the first frame.

        :param imagefile: path of the first frame, opened with ``Image.open``
        :param region: initial target region, converted with
            ``convert_bbox_format(region, 'center-based')``
        """
        super(TrackerSiamFC, self).__init__(name='SiamFC',
                                            is_deterministic=True)
        self.cfg = self.parse_args()

        # setup GPU device if available
        self.cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.cuda else 'cpu')

        # setup model
        self.net = SiamFC()
        # NOTE(review): the weights path is hard-coded to one machine.
        net_path = "/home/user/siamfc/pretrained/siamfc_new/model_e1_BEST.pth"
        if net_path is not None:
            # Always map the stored tensors to CPU first; the net is moved to
            # the selected device right after.
            self.net.load_state_dict(
                torch.load(net_path,
                           map_location=lambda storage, loc: storage))
        self.net = self.net.to(self.device)

        # setup optimizer
        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=self.cfg.initial_lr,
                                   weight_decay=self.cfg.weight_decay,
                                   momentum=self.cfg.momentum)

        # setup lr scheduler
        self.lr_scheduler = ExponentialLR(self.optimizer,
                                          gamma=self.cfg.lr_decay)
        # Weight of the correlation-filter response when update() fuses maps.
        self.cf_influence = 0.11
        bbox = convert_bbox_format(region, 'center-based')
        bbox = (bbox.x, bbox.y, bbox.width, bbox.height)
        image = Image.open(imagefile)
        # init() is declared as init(self, image, box); the arguments used to
        # be passed in the opposite order.
        # NOTE(review): init() converts its box from a top-left origin, while
        # the box here was converted to 'center-based' - confirm the format.
        self.init(image, bbox)

    def parse_args(self):
        """Return the tracker's default configuration as a ``GenericDict`` namedtuple.

        Fields are the inference parameters followed by the training
        parameters, in declaration order.
        """
        inference = dict(
            exemplar_sz=127,
            instance_sz=255,
            context=0.5,
            scale_num=3,
            scale_step=1.0375,
            scale_lr=0.59,
            scale_penalty=0.9745,
            # An earlier note mentions trying 0.1 here; 0.176 is in effect.
            window_influence=0.176,
            response_sz=17,
            response_up=16,
            total_stride=8,
            adjust_scale=0.001,
        )
        training = dict(
            # An earlier note records a change from 0.01 to 0.001.
            initial_lr=0.001,
            lr_decay=0.8685113737513527,
            weight_decay=5e-4,
            momentum=0.9,
            r_pos=16,
            r_neg=0,
        )
        cfg = {**inference, **training}

        config_type = namedtuple('GenericDict', list(cfg))
        return config_type(**cfg)

    def init(self, image, box):
        """Initialise the tracker state from the first frame.

        Stores the target centre and size, the Hann window, the scale
        factors, the exemplar/search crop sizes and the average colour, then
        extracts the exemplar kernel and primes the network's filter via
        ``self.net.update``.

        :param image: first frame; converted with ``np.asarray``
        :param box: indexable ``[x, y, w, h]``, treated below as 1-indexed
            with a top-left origin
        """
        self.net.is_train = False
        frame = np.asarray(image)

        # Re-express the box as 0-indexed, centre-based [y, x, h, w].
        left, top, width, height = box[0], box[1], box[2], box[3]
        state = np.array([
            top - 1 + (height - 1) / 2,
            left - 1 + (width - 1) / 2,
            height,
            width,
        ], dtype=np.float32)
        self.center = state[:2]
        self.target_sz = state[2:]

        # Normalised 2-D Hann window at the upsampled response resolution.
        self.upscale_sz = self.cfg.response_up * self.cfg.response_sz
        hann_1d = np.hanning(self.upscale_sz)
        self.hann_window = np.outer(hann_1d, hann_1d)
        self.hann_window /= self.hann_window.sum()

        # Scale factors: scale_step raised to symmetric exponents around 0.
        half_span = self.cfg.scale_num // 2
        exponents = np.linspace(-half_span, half_span, self.cfg.scale_num)
        self.scale_factors = self.cfg.scale_step**exponents

        # Exemplar (z) and search (x) crop sizes, including context padding.
        context = self.cfg.context * np.sum(self.target_sz)
        self.z_sz = np.sqrt(np.prod(self.target_sz + context))
        self.x_sz = self.z_sz * \
            self.cfg.instance_sz / self.cfg.exemplar_sz

        # The average colour is used to pad crops that leave the frame.
        self.avg_color = np.mean(frame, axis=(0, 1))
        z_crop = self._crop_and_resize(frame,
                                       self.center,
                                       self.z_sz,
                                       out_size=self.cfg.exemplar_sz,
                                       pad_color=self.avg_color)
        cf_crop = self._crop_and_resize(frame,
                                        self.center,
                                        self.x_sz,
                                        out_size=self.cfg.instance_sz,
                                        pad_color=self.avg_color)

        # Convert both crops to batched tensors on the tracker's device.
        z_tensor = Image_to_Tensor(z_crop).to(self.device).unsqueeze(0)
        cf_tensor = Image_to_Tensor(cf_crop).to(self.device).unsqueeze(0)

        with torch.set_grad_enabled(False):
            self.net.eval()
            _, kernel = self.net.features(z_tensor)
            # One kernel copy per search scale - presumably 3 mirrors
            # cfg.scale_num; confirm before changing either.
            self.kernel = kernel.repeat(3, 1, 1, 1)
            self.net.update(cf_tensor)

    def track(self, imagefile):
        """Track the target in the frame stored at ``imagefile``.

        Opens the image, runs ``update`` on it and returns the current state
        converted with ``convert_bbox_format(..., 'top-left-based')``.
        """
        frame = Image.open(imagefile)
        self.update(frame)
        center, size = self.center, self.target_sz
        # NOTE(review): init() stores the centre as [y, x] and the size as
        # [h, w]; confirm this matches the argument order Rectangle expects.
        region = Rectangle(center[0], center[1], size[0], size[1])
        return convert_bbox_format(region, 'top-left-based')

    def update(self, image):
        """Track the target in a new frame and refresh the tracker state.

        One search patch is cropped per scale factor. ``net.get_response``
        returns the SiamFC responses and the correlation-filter ("cf")
        responses; the best scale is picked from the SiamFC maps, both maps
        are normalised and fused with the Hann window, and the peak location
        moves ``self.center``. ``target_sz``, ``z_sz`` and ``x_sz`` are
        rescaled and the network's filter is updated at the new location.

        :param image: current frame; converted with ``np.asarray``
        :return: array ``[x, y, w, h]``, 1-indexed with a top-left origin
        """
        self.net.is_train = False
        image = np.asarray(image)

        # search images: one crop per scale factor
        instance_images = [
            self._crop_and_resize(image,
                                  self.center,
                                  self.x_sz * f,
                                  out_size=self.cfg.instance_sz,
                                  pad_color=self.avg_color)
            for f in self.scale_factors
        ]
        instance_images = [
            Image_to_Tensor(f).to(self.device) for f in instance_images
        ]
        instance_images = torch.stack(instance_images)

        # responses
        with torch.set_grad_enabled(False):
            self.net.eval()
            responses, cf_responses = self.net.get_response(
                self.kernel, instance_images)
        responses = responses.squeeze(1).cpu().numpy()
        cf_responses = cf_responses.squeeze(1).cpu().numpy()

        # cf branch ---------------------------------------------------------
        # Circularly shift the cf maps along both spatial axes.
        # NOTE(review): 251 is presumably the cf response size - confirm.
        cf_responses = np.roll(cf_responses,
                               int(np.floor(float(251) / 2.) - 1),
                               axis=1)
        cf_responses = np.roll(cf_responses,
                               int(np.floor(float(251) / 2.) - 1),
                               axis=2)
        # Upsample to 510x510, then keep the central 272x272 window. That is
        # the SiamFC upscale size (response_up * response_sz) under the
        # default configuration.
        cf_responses = np.stack([
            cv2.resize(t, (510, 510), interpolation=cv2.INTER_CUBIC)
            for t in cf_responses
        ],
                                axis=0)
        cf_responses = cf_responses[:, 119:391, 119:391]

        # siamfc branch -----------------------------------------------------
        # upsample responses and penalize scale changes
        responses = np.stack([
            cv2.resize(t, (self.upscale_sz, self.upscale_sz),
                       interpolation=cv2.INTER_CUBIC) for t in responses
        ],
                             axis=0)
        responses[:self.cfg.scale_num // 2] *= self.cfg.scale_penalty
        responses[self.cfg.scale_num // 2 + 1:] *= self.cfg.scale_penalty

        # peak scale
        scale_id = np.argmax(np.amax(responses, axis=(1, 2)))

        # peak location
        response = responses[scale_id]
        cf_response = cf_responses[scale_id]
        cf_response -= cf_response.min()
        # Same epsilon as the SiamFC map below: a flat cf map sums to zero
        # after the shift and would otherwise turn the fused response into NaN.
        cf_response /= cf_response.sum() + 1e-16
        response -= response.min()
        response /= response.sum() + 1e-16
        # Fuse the SiamFC response, the Hann window and the cf response.
        response = (1 - self.cfg.window_influence) * response + \
            self.cfg.window_influence * self.hann_window +  self.cf_influence * cf_response
        loc = np.unravel_index(response.argmax(), response.shape)

        # locate target center: response -> instance -> image coordinates
        disp_in_response = np.array(loc) - self.upscale_sz // 2
        disp_in_instance = disp_in_response * \
            self.cfg.total_stride / self.cfg.response_up
        disp_in_image = disp_in_instance * self.x_sz * \
            self.scale_factors[scale_id] / self.cfg.instance_sz

        self.center += disp_in_image

        # update target size
        scale =  (1 - self.cfg.scale_lr) * 1.0 + \
            self.cfg.scale_lr * self.scale_factors[scale_id]
        self.target_sz *= scale
        self.z_sz *= scale
        self.x_sz *= scale

        # update cf from a crop at the new location
        exemplar_image_cf = self._crop_and_resize(
            image,
            self.center,
            self.x_sz,
            out_size=self.cfg.instance_sz,
            pad_color=self.avg_color)
        exemplar_image_cf = Image_to_Tensor(exemplar_image_cf).to(
            self.device).unsqueeze(0)
        self.net.update(exemplar_image_cf, lr=0.01)

        # return 1-indexed and left-top based bounding box
        box = np.array([
            self.center[1] + 1 - (self.target_sz[1] - 1) / 2,
            self.center[0] + 1 - (self.target_sz[0] - 1) / 2,
            self.target_sz[1], self.target_sz[0]
        ])

        return box

    def step(self, batch, backward=True, update_lr=False):
        """Run one training or evaluation step on ``batch`` and return the loss.

        :param batch: indexable; items 0 and 1 are the two network inputs and
            item 2 is the target for the network's second output
        :param backward: when True the net is put in train mode and the
            optimizer takes a step; otherwise the net is put in eval mode and
            gradients are disabled
        :param update_lr: when True (and ``backward`` is True) the learning
            rate scheduler is stepped before the forward pass
        :return: ``loss1 + 5 * loss2`` as a Python float, where ``loss1`` is
            the weighted BCE-with-logits on the responses and ``loss2`` the
            MSE on the second output
        """
        self.net.is_train = True
        if backward:
            self.net.train()
            if update_lr:
                self.lr_scheduler.step()
        else:
            self.net.eval()

        z = batch[0].to(self.device)
        x = batch[1].to(self.device)
        label = batch[2].to(self.device)
        with torch.set_grad_enabled(backward):
            responses, out2 = self.net(z, x)
            labels, weights = self._create_labels(responses.size())
            # reduction='mean' is the non-deprecated spelling of the former
            # size_average=True.
            loss1 = F.binary_cross_entropy_with_logits(responses,
                                                       labels,
                                                       weight=weights,
                                                       reduction='mean')
            loss2 = F.mse_loss(out2, label, reduction='mean')

            loss = loss1 + loss2 * 5
            if backward:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        return loss.item()

    def _crop_and_resize(self, image, center, size, out_size, pad_color):
        """Cut a square patch around ``center`` and resize it to ``out_size``.

        The patch side is ``round(size)``. When the patch extends beyond the
        image, the image is first padded on every side with ``pad_color``.

        :param image: array indexed as ``[row, col, ...]``
        :param center: patch centre as ``[row, col]``
        :return: patch resized to ``(out_size, out_size)``
        """
        side = round(size)

        # Corners as [row0, col0, row1, col1], 0-indexed.
        first_corner = np.round(center - (side - 1) / 2)
        corners = np.concatenate((first_corner, first_corner + side))
        corners = np.round(corners).astype(int)

        # How far the patch sticks out of the image on its worst side.
        overflow = np.concatenate(
            (-corners[:2], corners[2:] - image.shape[:2]))
        npad = max(0, int(overflow.max()))
        if npad > 0:
            image = cv2.copyMakeBorder(image,
                                       npad,
                                       npad,
                                       npad,
                                       npad,
                                       cv2.BORDER_CONSTANT,
                                       value=pad_color)

        # Shift the corners into the padded image's coordinates and crop.
        row0, col0, row1, col1 = (corners + npad).astype(int)
        patch = image[row0:row1, col0:col1]

        return cv2.resize(patch, (out_size, out_size))

    def _create_labels(self, size):
        # skip if same sized labels already created
        if hasattr(self, 'labels') and self.labels.size() == size:
            return self.labels, self.weights

        def logistic_labels(x, y, r_pos, r_neg):
            dist = np.abs(x) + np.abs(y)  # block distance
            labels = np.where(
                dist <= r_pos, np.ones_like(x),
                np.where(dist < r_neg,
                         np.ones_like(x) * 0.5, np.zeros_like(x)))
            return labels

        # distances along x- and y-axis
        n, c, h, w = size
        x = np.arange(w) - w // 2
        y = np.arange(h) - h // 2
        x, y = np.meshgrid(x, y)

        # create logistic labels
        r_pos = self.cfg.r_pos / self.cfg.total_stride
        r_neg = self.cfg.r_neg / self.cfg.total_stride
        labels = logistic_labels(x, y, r_pos, r_neg)

        # pos/neg weights
        pos_num = np.sum(labels == 1)
        neg_num = np.sum(labels == 0)
        weights = np.zeros_like(labels)
        weights[labels == 1] = 0.5 / pos_num
        weights[labels == 0] = 0.5 / neg_num
        weights *= pos_num + neg_num

        # repeat to size
        labels = labels.reshape((1, 1, h, w))
        weights = weights.reshape((1, 1, h, w))
        labels = np.tile(labels, (n, c, 1, 1))
        weights = np.tile(weights, [n, c, 1, 1])

        # convert to tensors
        self.labels = torch.from_numpy(labels).to(self.device).float()
        self.weights = torch.from_numpy(weights).to(self.device).float()

        return self.labels, self.weights
Exemple #23
0
class Experiment:
    def __init__(self,
                 *,
                 model,
                 learning_rate=0.0005,
                 embedding_dim,
                 num_iterations,
                 batch_size=128,
                 decay_rate=0.,
                 conv_out=2,
                 projection_size=10,
                 input_dropout=0.4,
                 feature_map_dropout=0.4,
                 hidden_dropout=0.4,
                 label_smoothing=0.1,
                 cuda=True):
        """Store the hyperparameters and create the experiment folder and logger.

        All arguments are keyword-only.

        :param model: model name; compared against known names in
            ``train_and_eval`` and used as the logger name
        :param learning_rate: learning rate for the Adam optimizer
        :param embedding_dim: stored in ``self.kwargs`` for the model
        :param num_iterations: number of passes over the training pairs
        :param batch_size: number of (entity, relation) pairs per batch
        :param decay_rate: gamma of the exponential LR scheduler; a falsy
            value disables the scheduler
        :param conv_out, projection_size, input_dropout, feature_map_dropout,
            hidden_dropout: stored in ``self.kwargs`` for the model
        :param label_smoothing: smoothing factor applied to the targets; a
            falsy value disables it
        :param cuda: set to False to stay on the CPU even when CUDA is
            available
        """
        self.model = model
        self.learning_rate = learning_rate
        self.num_iterations = num_iterations
        self.batch_size = batch_size
        self.decay_rate = decay_rate
        self.label_smoothing = label_smoothing
        # Honour the caller's choice; the flag used to be ignored and CUDA
        # was used whenever it was available.
        self.cuda = bool(cuda) and torch.cuda.is_available()
        self.entity_idxs, self.relation_idxs, self.scheduler = None, None, None

        # Params stored in kwargs for creating pretrained models with unique names.
        self.kwargs = {
            'embedding_dim': embedding_dim,
            'learning_rate': learning_rate,
            'batch_size': batch_size,
            'conv_out': conv_out,
            'input_dropout': input_dropout,
            'hidden_dropout': hidden_dropout,
            'projection_size': projection_size,
            'feature_map_dropout': feature_map_dropout,
            'label_smoothing': label_smoothing,
            'decay_rate': decay_rate
        }

        self.storage_path, _ = create_experiment_folder()

        self.logger = create_logger(name=self.model, p=self.storage_path)

    def get_data_idxs(self, data):
        """Map ``(head, relation, tail)`` triples to their integer indices.

        Heads and tails are looked up in ``self.entity_idxs``, relations in
        ``self.relation_idxs``.

        :return: list of ``(head_idx, relation_idx, tail_idx)`` tuples
        """
        entity_of = self.entity_idxs
        relation_of = self.relation_idxs
        return [
            (entity_of[triple[0]], relation_of[triple[1]],
             entity_of[triple[2]])
            for triple in data
        ]

    def get_er_vocab(self, data):
        """Group tails by ``(head, relation)``.

        :return: ``defaultdict(list)`` mapping ``(head, relation)`` to the
            list of tails seen with that pair, in input order
        """
        tails_by_pair = defaultdict(list)
        for triple in data:
            head_relation = (triple[0], triple[1])
            tails_by_pair[head_relation].append(triple[2])
        return tails_by_pair

    def get_batch(self, er_vocab, er_vocab_pairs, idx):
        """Slice one batch of pairs and build its multi-hot target matrix.

        :param er_vocab: mapping from a pair to the entity indices to mark
        :param er_vocab_pairs: sequence of pairs to slice the batch from
        :param idx: start offset of the batch in ``er_vocab_pairs``
        :return: ``(batch as a NumPy array, FloatTensor of targets)``; the
            targets are moved to the GPU when ``self.cuda`` is set
        """
        pairs = er_vocab_pairs[idx:idx + self.batch_size]

        # One row per pair, one column per entity of the global dataset ``d``.
        multi_hot = np.zeros((len(pairs), len(d.entities)))
        for row, pair in enumerate(pairs):
            multi_hot[row, er_vocab[pair]] = 1.

        targets = torch.FloatTensor(multi_hot)
        if self.cuda:
            targets = targets.cuda()
        return np.array(pairs), targets

    def evaluate(self, model, data):
        """Log Hits@10, Hits@3, Hits@1 and the mean reciprocal rank of ``model`` on ``data``.

        Ranking is filtered: before sorting, the scores of every tail listed
        for the same (head, relation) in the global dataset ``d.data`` are
        set to zero, except for the tail being ranked.
        """
        hits = [[] for _ in range(10)]
        ranks = []

        test_data_idxs = self.get_data_idxs(data)
        er_vocab = self.get_er_vocab(self.get_data_idxs(d.data))

        for start in range(0, len(test_data_idxs), self.batch_size):
            data_batch, _ = self.get_batch(er_vocab, test_data_idxs, start)
            e1_idx, r_idx, e2_idx = (
                torch.tensor(data_batch[:, column]) for column in range(3))
            if self.cuda:
                e1_idx, r_idx, e2_idx = (
                    e1_idx.cuda(), r_idx.cuda(), e2_idx.cuda())
            predictions = model.forward(e1_idx, r_idx)

            batch_len = data_batch.shape[0]
            for j in range(batch_len):
                known_tails = er_vocab[(data_batch[j][0], data_batch[j][1])]
                # Keep the true tail's score while blanking the other tails.
                target_value = predictions[j, e2_idx[j]].item()
                predictions[j, known_tails] = 0.0
                predictions[j, e2_idx[j]] = target_value

            _, sort_idxs = torch.sort(predictions, dim=1, descending=True)
            sort_idxs = sort_idxs.cpu().numpy()

            for j in range(batch_len):
                # 0-based position of the true tail in the sorted scores.
                rank = np.where(sort_idxs[j] == e2_idx[j].item())[0][0]
                ranks.append(rank + 1)
                for level, bucket in enumerate(hits):
                    bucket.append(1.0 if rank <= level else 0.0)

        self.logger.info('Hits @10: {0}'.format(np.mean(hits[9])))

        self.logger.info('Hits @3: {0}'.format(np.mean(hits[2])))
        self.logger.info('Hits @1: {0}'.format(np.mean(hits[0])))
        reciprocal_ranks = 1. / np.array(ranks)
        self.logger.info('Mean reciprocal rank: {0}'.format(
            np.mean(reciprocal_ranks)))

    def train_and_eval(self, d_info):
        """Train the configured model on the global dataset ``d`` and evaluate it.

        Builds the entity/relation index maps, instantiates the model named
        by ``self.model``, trains it for ``self.num_iterations`` passes with
        Adam, validates every 500th pass, then writes ``settings.json``,
        evaluates on the test split and saves ``model.pt`` under
        ``self.storage_path``.

        :param d_info: dict merged into ``self.kwargs``; its 'dataset' entry
            is logged
        :raises ValueError: when ``self.model`` is not a known model name
        """
        self.entity_idxs = {
            entity: index for index, entity in enumerate(d.entities)
        }
        self.relation_idxs = {
            relation: index for index, relation in enumerate(d.relations)
        }
        train_data_idxs = self.get_data_idxs(d.train_data)

        self.kwargs.update({
            'num_entities': len(self.entity_idxs),
            'num_relations': len(self.relation_idxs)
        })
        self.kwargs.update(d_info)

        log = self.logger.info
        log("Info pertaining to dataset:{0}".format(d_info['dataset']))
        log("Number of triples in training data:{0}".format(
            len(d.train_data)))
        log("Number of triples in validation data:{0}".format(
            len(d.valid_data)))
        log("Number of triples in testing data:{0}".format(len(d.test_data)))
        log("Number of entities:{0}".format(len(self.entity_idxs)))
        log("Number of relations:{0}".format(len(self.relation_idxs)))

        log("HyperParameter Settings:{0}".format(self.kwargs))

        # Lazy factories: only the selected model class is looked up and built.
        factories = {
            'Conex': lambda: ConEx(self.kwargs),
            'Distmult': lambda: DistMult(self.kwargs),
            'Complex': lambda: Complex(self.kwargs),
            'Conve': lambda: ConvE(self.kwargs),
            'Tucker': lambda: TuckER(self.kwargs),
            'HypER': lambda: HypER(self.kwargs),
        }
        if self.model not in factories:
            raise ValueError
        model = factories[self.model]()

        log(model)
        if self.cuda:
            model.cuda()
        model.init()
        opt = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
        if self.decay_rate:
            self.scheduler = ExponentialLR(opt, self.decay_rate)

        er_vocab = self.get_er_vocab(train_data_idxs)
        er_vocab_pairs = list(er_vocab.keys())

        log("{0} starts training".format(model.name))
        num_param = sum(p.numel() for p in model.parameters())
        log("'Number of free parameters: {0}".format(num_param))

        for it in range(1, self.num_iterations + 1):
            model.train()
            losses = []
            np.random.shuffle(er_vocab_pairs)
            for offset in range(0, len(er_vocab_pairs), self.batch_size):
                data_batch, targets = self.get_batch(er_vocab, er_vocab_pairs,
                                                     offset)
                opt.zero_grad()
                e1_idx = torch.tensor(data_batch[:, 0])
                r_idx = torch.tensor(data_batch[:, 1])
                if self.cuda:
                    e1_idx = e1_idx.cuda()
                    r_idx = r_idx.cuda()
                predictions = model.forward(e1_idx, r_idx)
                if self.label_smoothing:
                    # Scale the targets down and add a uniform
                    # 1 / num_columns term.
                    targets = ((1.0 - self.label_smoothing) *
                               targets) + (1.0 / targets.size(1))
                loss = model.loss(predictions, targets)
                loss.backward()
                opt.step()
                losses.append(loss.item())
            if self.decay_rate:
                self.scheduler.step()

            if it % 500 != 0:
                continue

            # Periodic report and validation.
            log('Iteration:{0} with Average loss{1}'.format(
                it, np.mean(losses)))
            model.eval()  # evaluation mode, i.e. dropouts are turned off
            with torch.no_grad():
                log("Validation:")
                self.evaluate(model, d.valid_data)

        settings_path = self.storage_path + '/settings.json'
        with open(settings_path, 'w') as file_descriptor:
            json.dump(self.kwargs, file_descriptor)

        log("Testing:")
        self.evaluate(model, d.test_data)
        torch.save(model.state_dict(), self.storage_path + '/model.pt')
                'reconstructions':      wandb.Image(recons, caption = 'reconstructions'),
                'hard reconstructions': wandb.Image(hard_recons, caption = 'hard reconstructions'),
                'codebook_indices':     wandb.Histogram(codes),
                'temperature':          temp
            }

            save_model(f'./vae.pt')
            wandb.save('./vae.pt')

            # temperature anneal

            temp = max(temp * math.exp(-ANNEAL_RATE * global_step), TEMP_MIN)

            # lr decay

            sched.step()

        if i % 10 == 0:
            lr = sched.get_last_lr()[0]
            print(epoch, i, f'lr - {lr:6f} loss - {loss.item()}')

            logs = {
                **logs,
                'epoch': epoch,
                'iter': i,
                'loss': loss.item(),
                'lr': lr
            }

        wandb.log(logs)
        global_step += 1
Exemple #25
0
def main():
    """Build the fine-tuning model, then train and validate it per epoch.

    Configuration comes from the module-level ``args``; progress goes to the
    module-level ``writer`` (only when ``args.tensorboard`` is set) and to
    ``csv_writer``. A checkpoint is written after every epoch and the global
    ``best_loss`` tracks the lowest validation loss seen so far.

    When ``args.predict`` or ``args.evaluate`` is set, only the model and
    optimizer setup and the optional checkpoint restore are performed.

    Raises:
        ValueError: if ``args.optimizer`` does not start with ``adam``,
            ``rmsprop`` or ``sgd``.
    """
    global args, best_loss
    global writer, csv_writer
    global device, kwargs

    base_model = load_model(arch=args.arch, pretrained=True)

    model = FineTuneModelPool(base_model, args.arch, args.num_classes,
                              str(args.classifier_config))
    if args.use_parallel:
        model = torch.nn.DataParallel(model)
    model = model.to(device)

    # DataParallel exposes the wrapped network through ``.module``.
    net = model.module if args.use_parallel else model
    mean = net.mean
    std = net.std

    # Only parameters that still require gradients are handed to the optimizer.
    trainable_params = filter(lambda p: p.requires_grad, net.parameters())
    if args.optimizer.startswith('adam'):
        optimizer = torch.optim.Adam(trainable_params, lr=args.lr)
    elif args.optimizer.startswith('rmsprop'):
        optimizer = torch.optim.RMSprop(trainable_params, lr=args.lr)
    elif args.optimizer.startswith('sgd'):
        optimizer = torch.optim.SGD(trainable_params, lr=args.lr)
    else:
        raise ValueError('Optimizer not supported')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location lets a checkpoint saved on another device be
            # restored onto the device used by this run.
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            # NOTE: the optimizer state stored in the checkpoint is not
            # restored; training resumes with a fresh optimizer.
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Prediction and evaluation modes do nothing further in this function.
    if args.predict or args.evaluate:
        return

    base_dset_kwargs = {
        'mode': 'train',
        'random_state': args.seed,
        'fold': args.fold,
        'size_ratio': args.size_ratio,
        'preprocessing_type': args.preprocessing_type,
        'fixed_size': (224, 224),
        'prob': 0.2,
        'mean': mean,
        'std': std,
    }

    (train_dataset, val_dataset, train_sampler, val_sampler, train_loader,
     val_loader) = get_datasets(base_dset_kwargs, args.batch_size)

    if args.multi_class:
        criterion = MultiClassBCELoss().to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    hard_dice_05 = HardDice(threshold=0.5)

    # Default to no scheduler so that an unrecognised ``lr_regime`` does not
    # leave the name unbound when it is passed to ``train`` below.
    # NOTE(review): assumes ``train`` tolerates ``scheduler=None`` - confirm.
    scheduler = None
    if args.lr_regime == 'auto_decay':
        print('Auto-decay')
        scheduler = ExponentialLR(optimizer=optimizer,
                                  gamma=0.99,
                                  last_epoch=-1)
    elif args.lr_regime == 'plateau_decay':
        print('Plateau decay')
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      mode='min',
                                      factor=0.5,
                                      patience=5,
                                      verbose=True)
    elif args.lr_regime == 'clr':
        print('CLR decay')
        scheduler = CyclicLR(optimizer=optimizer,
                             base_lr=1e-4,
                             max_lr=1e-2,
                             step_size=1200,
                             mode='exp_range',
                             gamma=0.95)

    for epoch in range(args.start_epoch, args.epochs):
        # Progressive resizing: every ``epochs_grow_size`` epochs double the
        # dataset size ratio (until it reaches 1.0) and quarter the batch size.
        if (args.epochs_grow_size > 0
                and (epoch + 1) % args.epochs_grow_size == 0
                and train_dataset.size_ratio < 1.0):
            # Never let the batch size drop to zero.
            new_batch_size = max(1, int(train_loader.batch_size // 4))
            (train_dataset, val_dataset, train_sampler, val_sampler,
             train_loader, val_loader) = get_datasets(
                 {
                     **base_dset_kwargs,
                     'size_ratio': train_dataset.size_ratio * 2,
                 }, new_batch_size)

        # train for one epoch
        train_loss, train_hard_dice_05, train_f1, train_acc1, train_acc5 = train(
            train_loader, model, criterion, hard_dice_05, optimizer, epoch,
            scheduler)

        # evaluate on validation set
        val_loss, val_hard_dice_05, val_f1, val_acc1, val_acc5 = validate(
            val_loader, model, criterion, hard_dice_05)

        # Epoch-level schedulers are stepped here; the cyclic scheduler is
        # not stepped in this function.
        if args.lr_regime == 'auto_decay':
            scheduler.step()
        elif args.lr_regime == 'plateau_decay':
            scheduler.step(val_loss)

        #============ TensorBoard logging ============#
        # Log the scalar values
        # NOTE(review): missing metrics are only reset to 0 when TensorBoard
        # logging is enabled, so the CSV rows below differ between the two
        # modes - confirm whether that is intended.
        if args.tensorboard:
            writer.add_scalars('epoch/epoch_losses', {
                'train_loss': train_loss,
                'val_loss': val_loss
            }, epoch + 1)
            if train_hard_dice_05 and val_hard_dice_05:
                writer.add_scalars(
                    'epoch/epoch_hdice05', {
                        'train_hdice': train_hard_dice_05,
                        'val_hdice': val_hard_dice_05
                    }, epoch + 1)
            if train_acc1 and val_acc1:
                writer.add_scalars('epoch/epoch_acc1', {
                    'train_acc1': train_acc1,
                    'val_acc1': val_acc1
                }, epoch + 1)
            else:
                train_acc1 = 0
                val_acc1 = 0

            if train_acc5 and val_acc5:
                writer.add_scalars('epoch/epoch_acc5', {
                    'train_acc5': train_acc5,
                    'val_acc5': val_acc5
                }, epoch + 1)
            else:
                train_acc5 = 0
                val_acc5 = 0

            if train_f1 and val_f1:
                writer.add_scalars('epoch/epoch_f1', {
                    'train_f1': train_f1,
                    'val_f1': val_f1
                }, epoch + 1)
            else:
                train_f1 = 0
                val_f1 = 0

        csv_writer.write({
            'epoch': epoch + 1,
            'train_acc1': train_acc1,
            'val_acc1': val_acc1,
            'train_acc5': train_acc5,
            'val_acc5': val_acc5,
            'train_f1': train_f1,
            'val_f1': val_f1
        })

        # remember the best (lowest) validation loss and save a checkpoint
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'optimizer': optimizer.state_dict(),
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
            }, is_best,
            'weights/{}_checkpoint.pth.tar'.format(str(args.lognumber)),
            'weights/{}_best.pth.tar'.format(str(args.lognumber)))
Exemple #26
0
    def train_and_eval(self):
        """Train the model selected by ``model_name`` and evaluate it periodically.

        Builds entity/relation index maps from the module-level dataset ``d``,
        trains for ``self.epochs`` epochs with Adam (plus an exponential LR
        decay when ``self.decay_rate`` is set), and every 10th epoch evaluates
        on a random training sample, the validation set and the test set
        before writing a checkpoint. The final weights are saved once after
        the last epoch.

        Raises:
            ValueError: if ``model_name`` is not one of the supported models.
        """
        logger.info(
            f'Training the {model_name} model with {dataset} knowledge graph ...'
        )
        self.entity_idxs = {
            entity: idx for idx, entity in enumerate(d.entities)
        }
        self.relation_idxs = {
            relation: idx for idx, relation in enumerate(d.relations)
        }
        train_triple_idxs = self.get_data_idxs(d.train_data)
        train_triple_size = len(train_triple_idxs)
        logger.info(f'Number of training data points: {train_triple_size}')

        architecture = model_name.lower()
        if architecture == "hype":
            model = HypE(d, self.ent_vec_dim, self.rel_vec_dim, **self.kwargs)
        elif architecture == "hyper":
            model = HypER(d, self.ent_vec_dim, self.rel_vec_dim, **self.kwargs)
        elif architecture == "distmult":
            model = DistMult(d, self.ent_vec_dim, self.rel_vec_dim,
                             **self.kwargs)
        elif architecture == "conve":
            model = ConvE(d, self.ent_vec_dim, self.rel_vec_dim, **self.kwargs)
        elif architecture == "complex":
            model = ComplEx(d, self.ent_vec_dim, self.rel_vec_dim,
                            **self.kwargs)
        else:
            # Fail with a clear message instead of an unbound ``model`` below.
            raise ValueError(f'Unsupported model name: {model_name}')
        logger.debug('model parameters: {}'.format(
            {name: value.numel()
             for name, value in model.named_parameters()}))

        if self.cuda:
            model.cuda()

        model.init()
        opt = torch.optim.Adam(model.parameters(), lr=self.learning_rate)

        if self.decay_rate:
            scheduler = ExponentialLR(opt, self.decay_rate)

        er_vocab = self.get_er_vocab(train_triple_idxs)
        er_vocab_pairs = list(er_vocab.keys())
        er_vocab_pairs_size = len(er_vocab_pairs)
        logger.info(
            f'Number of entity-relational pairs: {er_vocab_pairs_size}')

        logger.info('Starting Training ...')

        for epoch in range(1, self.epochs + 1):
            logger.info(f'Epoch: {epoch}')

            model.train()
            costs = []
            np.random.shuffle(er_vocab_pairs)

            for j in range(0, er_vocab_pairs_size, self.batch_size):
                # Progress message whenever the offset is a multiple of 12800.
                if j % (128 * 100) == 0:
                    logger.info(f'Batch: {j + 1} ...')

                triples, targets = self.get_batch(er_vocab, er_vocab_pairs,
                                                  er_vocab_pairs_size, j)
                opt.zero_grad()
                e1_idx = torch.tensor(triples[:, 0])
                r_idx = torch.tensor(triples[:, 1])

                if self.cuda:
                    e1_idx = e1_idx.cuda()
                    r_idx = r_idx.cuda()

                predictions = model.forward(e1_idx, r_idx)

                if self.label_smoothing:
                    # Scale the targets down and add 1 / targets.size(1) to
                    # every entry.
                    targets = ((1.0 - self.label_smoothing) *
                               targets) + (1.0 / targets.size(1))

                cost = model.loss(predictions, targets)
                cost.backward()
                opt.step()

                costs.append(cost.item())

            # The learning rate decays once per epoch.
            if self.decay_rate:
                scheduler.step()

            logger.info(f'Mean training cost: {np.mean(costs)}')

            if epoch % 10 == 0:
                model.eval()
                with torch.no_grad():
                    train_data = np.array(d.train_data)
                    # Per-dataset number of training triples to sample for
                    # the training-set evaluation.
                    train_data_map = {
                        'WN18': 10000,
                        'FB15k': 100000,
                        'WN18RR': 6068,
                        'FB15k-237': 35070
                    }
                    train_data_sample_size = train_data_map[dataset]
                    train_data = train_data[
                        np.random.choice(train_data.shape[0],
                                         train_data_sample_size,
                                         replace=False), :]

                    logger.info('Starting Evaluation: Training ...')
                    self.evaluate(model, train_data, epoch, 'training')
                    logger.info('Evaluation: Training complete!')
                    logger.info('Starting Evaluation: Validation ...')
                    self.evaluate(model, d.valid_data, epoch, 'validation')
                    logger.info('Evaluation: Validation complete!')
                    logger.info('Starting Evaluation: Test ...')
                    self.evaluate(model, d.test_data, epoch, 'testing')
                    logger.info('Evaluation: Test complete!')

                    logger.info('Checkpointing model ...')
                    torch.save(model.state_dict(), 'HypER.mc')
                    logger.info('Model checkpoint complete!')

        # Save the final weights once, after the last epoch, rather than on
        # every iteration of the epoch loop.
        logger.info('Saving final model ...')
        torch.save(model.state_dict(), 'HypER.pt')
        logger.info('Saving final model complete!')
Exemple #27
0
def train(data_path,
          entity_path,
          relation_path,
          entity_dict,
          relation_dict,
          neg_batch_size,
          batch_size,
          shuffle,
          num_workers,
          nb_epochs,
          embedding_dim,
          hidden_dim,
          relation_dim,
          gpu,
          use_cuda,
          patience,
          freeze,
          validate_every,
          num_hops,
          lr,
          entdrop,
          reldrop,
          scoredrop,
          l3_reg,
          model_name,
          decay,
          ls,
          w_matrix,
          bn_list,
          valid_data_path=None):
    """Train a ``RelationExtractor`` and keep the best-validating weights.

    Each epoch runs ``validate_every`` training passes over the data followed
    by one validation pass. When the validation score improves by more than
    a small epsilon, the weights are snapshotted, the test set is scored and
    a checkpoint is written. The process exits once ``patience`` consecutive
    non-improving validations have been seen, or after the final epoch.

    Args:
        data_path: training text file, read with ``process_text_file``.
        entity_path, relation_path: ``.npy`` files loaded with ``np.load``.
        entity_dict, relation_dict: passed to
            ``preprocess_entities_relations`` together with the loaded arrays.
        neg_batch_size: accepted but not used in this function.
        batch_size: batch size of the training data loader.
        shuffle: accepted but not used; the loader always shuffles.
        num_workers: worker count of the training data loader.
        nb_epochs: number of epochs.
        embedding_dim, hidden_dim, relation_dim, entdrop, reldrop, scoredrop,
        l3_reg, ls, w_matrix, bn_list: forwarded to ``RelationExtractor``.
        gpu, use_cuda: ``gpu`` is used as the device when ``use_cuda`` is
            truthy, otherwise the CPU is used.
        patience: number of non-improving validations tolerated.
        freeze: forwarded to the model; when ``True`` ``set_bn_eval`` is
            applied during training and the checkpoint name gets a
            ``_frozen`` suffix.
        validate_every: number of training passes per validation pass.
        num_hops: used (as a string) in log messages and checkpoint names.
        lr: Adam learning rate.
        model_name: forwarded to the model and to ``validate``; also part of
            the checkpoint name.
        decay: gamma of the exponential learning-rate schedule.
        valid_data_path: validation data passed to ``validate``.
    """
    # Function-scope import keeps this block self-contained.
    import copy

    entities = np.load(entity_path)
    relations = np.load(relation_path)
    e, r = preprocess_entities_relations(entity_dict, relation_dict, entities,
                                         relations)
    entity2idx, idx2entity, embedding_matrix = prepare_embeddings(e)
    data = process_text_file(data_path, split=False)
    word2ix, idx2word, max_len = get_vocab(data)
    hops = str(num_hops)
    device = torch.device(gpu if use_cuda else "cpu")
    dataset = DatasetMetaQA(data=data,
                            word2ix=word2ix,
                            relations=r,
                            entities=e,
                            entity2idx=entity2idx)
    # NOTE(review): the ``shuffle`` argument is ignored here; shuffling is
    # always on. Confirm before wiring the parameter through.
    data_loader = DataLoaderMetaQA(dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=num_workers)
    model = RelationExtractor(embedding_dim=embedding_dim,
                              hidden_dim=hidden_dim,
                              vocab_size=len(word2ix),
                              num_entities=len(idx2entity),
                              relation_dim=relation_dim,
                              pretrained_embeddings=embedding_matrix,
                              freeze=freeze,
                              device=device,
                              entdrop=entdrop,
                              reldrop=reldrop,
                              scoredrop=scoredrop,
                              l3_reg=l3_reg,
                              model=model_name,
                              ls=ls,
                              w_matrix=w_matrix,
                              bn_list=bn_list)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = ExponentialLR(optimizer, decay)
    optimizer.zero_grad()
    best_score = -float("inf")
    # state_dict() returns references to the live tensors, so a deep copy is
    # required for the snapshot to stay fixed while training continues.
    best_model = copy.deepcopy(model.state_dict())
    no_update = 0
    # Minimum improvement for a validation score to count as better.
    eps = 0.0001
    # Defined up front so every save below can rely on it being bound.
    checkpoint_path = '../../checkpoints/MetaQA/'
    # ``validate_every`` training passes, then one validation pass.
    phases = ['train'] * validate_every + ['valid']
    for epoch in range(nb_epochs):
        for phase in phases:
            if phase == 'train':
                model.train()
                if freeze == True:
                    model.apply(set_bn_eval)
                loader = tqdm(data_loader,
                              total=len(data_loader),
                              unit="batches")
                running_loss = 0
                for i_batch, a in enumerate(loader):
                    model.zero_grad()
                    question = a[0].to(device)
                    sent_len = a[1].to(device)
                    positive_head = a[2].to(device)
                    positive_tail = a[3].to(device)

                    # The model's forward pass returns the training loss.
                    loss = model(sentence=question,
                                 p_head=positive_head,
                                 p_tail=positive_tail,
                                 question_len=sent_len)
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item()
                    loader.set_postfix(Loss=running_loss /
                                       ((i_batch + 1) * batch_size),
                                       Epoch=epoch)
                    loader.set_description('{}/{}'.format(epoch, nb_epochs))
                    loader.update()

                # The learning rate decays once per training pass.
                scheduler.step()

            elif phase == 'valid':
                model.eval()
                answers, score = validate(model=model,
                                          data_path=valid_data_path,
                                          word2idx=word2ix,
                                          entity2idx=entity2idx,
                                          device=device,
                                          model_name=model_name)
                if score > best_score + eps:
                    best_score = score
                    no_update = 0
                    best_model = copy.deepcopy(model.state_dict())
                    print(
                        hops +
                        " hop Validation accuracy increased from previous epoch",
                        score)
                    # NOTE(review): ``test_data_path`` is not a parameter of
                    # this function; presumably a module-level global -
                    # confirm.
                    _, test_score = validate(model=model,
                                             data_path=test_data_path,
                                             word2idx=word2ix,
                                             entity2idx=entity2idx,
                                             device=device,
                                             model_name=model_name)
                    print('Test score for best valid so far:', test_score)
                    suffix = ''
                    if freeze == True:
                        suffix = '_frozen'
                    # Use the stringified hop count so an integer
                    # ``num_hops`` does not break the concatenation.
                    checkpoint_file_name = checkpoint_path + model_name + '_' + hops + suffix + ".pt"
                    print('Saving checkpoint to ', checkpoint_file_name)
                    torch.save(model.state_dict(), checkpoint_file_name)
                elif (score < best_score + eps) and (no_update < patience):
                    no_update += 1
                    print(
                        "Validation accuracy decreases to %f from %f, %d more epoch to check"
                        % (score, best_score, patience - no_update))
                elif no_update == patience:
                    print(
                        "Model has exceed patience. Saving best model and exiting"
                    )
                    torch.save(best_model,
                               checkpoint_path + "best_score_model.pt")
                    exit()
                if epoch == nb_epochs - 1:
                    print(
                        "Final Epoch has reached. Stopping and saving model.")
                    torch.save(best_model,
                               checkpoint_path + "best_score_model.pt")
                    exit()
            optimizer.zero_grad()

            # torch to variable
            X_train_variable = Variable(torch.from_numpy(X_train).float())
            y_train_variable = Variable(torch.from_numpy(y_train).float())

            X_val_variable = Variable(torch.from_numpy(X_val).float())
            y_val_variable = Variable(torch.from_numpy(y_val).float())

            # train model
            mlp.train()
            outputs = mlp(X_train_variable)
            loss = criterion(outputs, y_train_variable)
            loss.backward()
            optimizer.step()
            scheduler.step()

            if epoch % 100 == 0:
                print('Epoch [%d/%d], Validation Loss: %.4f' %
                      (epoch + 1, num_epochs, loss.data[0]))

        # validate model
        mlp.eval()  # for the case of dropout-layer and Batch-normalization
        outputs = mlp(X_val_variable)
        validation_loss = criterion(outputs, y_val_variable)

        #print(validation_loss.data[0])
        #total_loss += validation_loss.data[0]
        #print(total_loss)
        print("End of one cross validation subset")
Exemple #29
0
def main():
    """Jointly train the generator, discriminator and classifier on CUDA.

    Runs 6 epochs. For every training batch the classifier, the
    discriminator and the generator are updated in that order, each with its
    own Adam optimizer and gradient-norm clipping at 10. Every 100 batches
    the averaged generator loss terms are printed and the averaged
    trajectory MSE is appended to ``./checkpoints/6f/loss.txt``. After each
    epoch the generator is scored on the validation set, all three networks
    are checkpointed and the learning rates are decayed.
    """
    learning_rate = 0.0005
    BCEloss = nn.BCELoss().cuda()
    MSEloss = nn.MSELoss().cuda()

    generator = model.Generator()
    discriminator = model.Discriminator()
    classified = model.Classified()
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    classified = classified.cuda()

    generator.train()
    discriminator.train()
    classified.train()
    t1 = lo.NgsimDataset('data/5feature/TrainSet.mat')
    t2 = lo.NgsimDataset('data/5feature/ValSet.mat')
    trainDataloader = DataLoader(t1,
                                 batch_size=128,
                                 shuffle=True,
                                 num_workers=8,
                                 collate_fn=t1.collate_fn)
    valDataloader = DataLoader(t2,
                               batch_size=128,
                               shuffle=True,
                               num_workers=8,
                               collate_fn=t2.collate_fn)
    optimizer_g = optim.Adam(generator.parameters(), lr=learning_rate)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=learning_rate)
    optimizer_c = optim.Adam(classified.parameters(), lr=learning_rate)

    scheduler_g = ExponentialLR(optimizer_g, gamma=0.6)
    scheduler_d = ExponentialLR(optimizer_d, gamma=0.5)
    scheduler_c = ExponentialLR(optimizer_c, gamma=0.6)

    # The context manager closes the loss log even if training raises.
    with open('./checkpoints/6f/loss.txt', 'w') as loss_log:
        for epoch in range(6):
            print("epoch:", epoch, 'lr', optimizer_d.param_groups[0]['lr'])
            # Running sums of the generator loss terms over 100 batches.
            loss_gi1 = 0
            loss_gix = 0
            loss_gi3 = 0
            loss_gi4 = 0
            for idx, data in enumerate(trainDataloader):
                # Move every tensor of the batch to the GPU.
                (hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask, va, nbrsva,
                 lane, nbrslane, dis,
                 nbrsdis) = (tensor.cuda() for tensor in data)

                # Train the classifier (C) on the real trajectory.
                traj = t.cat((hist, fut), 0)
                c_out = classified(traj)
                loss_c = BCEloss(c_out, lat_enc)
                optimizer_c.zero_grad()
                loss_c.backward()
                t.nn.utils.clip_grad_norm_(classified.parameters(), 10)
                optimizer_c.step()

                # Train the discriminator (D): real trajectories are
                # labelled 1, generated ones 0.
                real_data, _ = discriminator(traj, nbrs, mask)
                g_out, _, _ = generator(hist, nbrs, mask, lat_enc, lon_enc,
                                        va, nbrsva, lane, nbrslane, dis,
                                        nbrsdis)
                # Detach the generated future: this step only updates the
                # discriminator, and generator gradients are zeroed before
                # the generator's own backward pass anyway.
                fake_data, _ = discriminator(
                    t.cat((hist, g_out.detach()), 0), nbrs, mask)
                real_label = t.ones_like(real_data)
                fake_label = t.zeros_like(fake_data)
                loss_d1 = BCEloss(real_data, real_label)
                loss_d2 = BCEloss(fake_data, fake_label)
                loss_d = loss_d1 + loss_d2
                optimizer_d.zero_grad()
                loss_d.backward()
                t.nn.utils.clip_grad_norm_(discriminator.parameters(), 10)
                optimizer_d.step()

                # Train the generator (G) on four terms: trajectory MSE,
                # manoeuvre BCE, classifier BCE on the generated trajectory,
                # and an MSE between the discriminator's second output for
                # fake and real trajectories.
                g_out, lat_pred, lon_pred = generator(hist, nbrs, mask,
                                                      lat_enc, lon_enc, va,
                                                      nbrsva, lane, nbrslane,
                                                      dis, nbrsdis)
                loss_g1 = MSELoss2(g_out, fut, op_mask)
                loss_gx = BCEloss(lat_pred, lat_enc) + BCEloss(
                    lon_pred, lon_enc)
                traj_fake = t.cat((hist, g_out), 0)
                traj_true = t.cat((hist, fut), 0)
                c_out = classified(traj_fake)
                loss_g3 = BCEloss(c_out, lat_enc)
                _, outp1 = discriminator(traj_fake, nbrs, mask)
                _, outp2 = discriminator(traj_true, nbrs, mask)
                loss_g4 = MSEloss(outp1, outp2)
                loss_g = loss_g1 + loss_gx + 5 * loss_g3 + 5 * loss_g4
                optimizer_g.zero_grad()
                loss_g.backward()
                t.nn.utils.clip_grad_norm_(generator.parameters(), 10)
                optimizer_g.step()
                loss_gi1 += loss_g1.item()
                loss_gix += loss_gx.item()
                loss_gi3 += loss_g3.item()
                loss_gi4 += loss_g4.item()
                if idx % 100 == 99:
                    print('mse:', loss_gi1 / 100, '|c1:', loss_gix / 100,
                          '|c:', loss_gi3 / 100, '|d:', loss_gi4 / 100)
                    # Only the averaged trajectory MSE is written to the log.
                    loss_log.write(str(loss_gi1 / 100) + ',')
                    loss_gi1 = 0
                    loss_gix = 0
                    loss_gi3 = 0
                    loss_gi4 = 0

            avg_val_loss = 0
            val_batch_count = 0
            print('startval:')
            # No gradients are needed for validation, so skip graph building.
            # NOTE(review): the generator stays in train mode here, as in the
            # original code - confirm whether eval mode is wanted.
            with t.no_grad():
                for data in valDataloader:
                    (hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask, va,
                     nbrsva, lane, nbrslane, dis,
                     nbrsdis) = (tensor.cuda() for tensor in data)

                    fut_pred, _, _ = generator(hist, nbrs, mask, lat_enc,
                                               lon_enc, va, nbrsva, lane,
                                               nbrslane, dis, nbrsdis)
                    l = MSELoss2(fut_pred, fut, op_mask)

                    avg_val_loss += l.item()
                    val_batch_count += 1

            print('valmse:', avg_val_loss / val_batch_count)
            t.save(generator.state_dict(),
                   'checkpoints/6f/epoch' + str(epoch + 1) + '_g.tar')
            t.save(discriminator.state_dict(),
                   'checkpoints/6f/epoch' + str(epoch + 1) + '_d.tar')
            t.save(classified.state_dict(),
                   'checkpoints/6f/epoch' + str(epoch + 1) + '_c.tar')
            scheduler_g.step()
            scheduler_d.step()
            scheduler_c.step()
Exemple #30
0
    def train_and_eval(self):
        """Train a TuckER model and periodically score it on validation/test.

        Index maps for entities and relations are built from the module-level
        dataset ``d``. Training runs for ``self.num_iterations`` iterations
        with Adam, decaying the learning rate once per iteration when
        ``self.decay_rate`` is set. Evaluation runs every 50th iteration up
        to iteration 1500 and on every iteration after that.
        """
        print("Training the TuckER model...")
        self.entity_idxs = {
            entity: index for index, entity in enumerate(d.entities)
        }
        self.relation_idxs = {
            relation: index for index, relation in enumerate(d.relations)
        }

        train_data_idxs = self.get_data_idxs(d.train_data)
        print("Number of training data points: %d" % len(train_data_idxs))

        model = TuckER(d, self.ent_vec_dim, self.rel_vec_dim, **self.kwargs)
        if self.cuda:
            model.cuda()
        model.init()
        opt = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
        # A scheduler exists only when a decay rate is configured.
        scheduler = (ExponentialLR(opt, self.decay_rate)
                     if self.decay_rate else None)

        er_vocab = self.get_er_vocab(train_data_idxs)
        er_vocab_pairs = list(er_vocab.keys())

        print("Starting training...")
        for it in range(1, self.num_iterations + 1):
            iteration_start = time.time()
            model.train()
            batch_losses = []
            np.random.shuffle(er_vocab_pairs)
            for offset in range(0, len(er_vocab_pairs), self.batch_size):
                data_batch, targets = self.get_batch(er_vocab, er_vocab_pairs,
                                                     offset)
                opt.zero_grad()
                e1_idx = torch.tensor(data_batch[:, 0])
                r_idx = torch.tensor(data_batch[:, 1])
                if self.cuda:
                    e1_idx, r_idx = e1_idx.cuda(), r_idx.cuda()
                predictions = model.forward(e1_idx, r_idx)
                if self.label_smoothing:
                    # Scale the targets down and add 1 / targets.size(1) to
                    # every entry.
                    keep_weight = 1.0 - self.label_smoothing
                    targets = (keep_weight * targets) + (1.0 /
                                                         targets.size(1))
                loss = model.loss(predictions, targets)
                loss.backward()
                opt.step()
                batch_losses.append(loss.item())

            if scheduler is not None:
                scheduler.step()

            elapsed = time.time() - iteration_start
            print("Iteration %d, time taken %.3f, loss %.10f " %
                  (it, elapsed, np.mean(batch_losses)))

            model.eval()
            with torch.no_grad():
                # Sparse evaluation early on, dense evaluation late.
                if it > 1500 or it % 50 == 0:
                    print("Validation:")
                    self.evaluate(model, d.valid_data)
                    print("Test:")
                    self.evaluate(model, d.test_data)
def run(train_batch_size, val_batch_size, epochs, lr, log_interval, channels, classes):
    """Train the 'dense121' model with staged unfreezing, checkpointing and early stopping.

    Besides its arguments, the function reads the module-level ``args`` object
    (``pretrained``, ``checkpoint``, ``gpu``, ``augmentation``, ``distributed``)
    and writes checkpoints under ``models/``.

    Args:
        train_batch_size: Forwarded to ``get_data_loaders`` for the training loader.
        val_batch_size: Forwarded to ``get_data_loaders`` for the validation loader.
        epochs: Training stops once the epoch counter reaches this value.
        lr: Initial learning rate handed to Adam.
        log_interval: Number of training batches between progress-bar updates.
        channels: Sequence of channel ids; its length is passed to the model
            constructor and its items are concatenated into checkpoint file names.
        classes: Passed to the model constructor and to ``get_data_loaders``.

    Returns:
        None.
    """
    # load model
    model = get_featureExtractor('dense121')(classes, len(channels))

    # load saved weights; the starting epoch is parsed from the checkpoint name
    if args.pretrained:
        # truthiness instead of len() so a checkpoint of None falls back to the default file
        if args.checkpoint:
            cpfile = 'models/' + args.checkpoint + '.pth'
            ep0 = epoch = int(args.checkpoint.rsplit('_', 1)[1])
        else:
            cpfile = 'models/Model_pretrained_DenseNet121.pth'
            ep0 = epoch = 1
        # Load onto the CPU first so a checkpoint written on a GPU machine can
        # still be restored on a CPU-only host; the model is moved to the GPU below.
        checkpoint = torch.load(cpfile, map_location='cpu')
        model.load_state_dict(checkpoint)
    else:
        ep0 = epoch = 0

    # to gpu
    # NOTE(review): the model goes to GPU `args.gpu` while batches are sent to
    # plain "cuda" (the current default device) -- confirm these always agree.
    if torch.cuda.is_available():
        device = "cuda"
        model.cuda(args.gpu)
    else:
        device = "cpu"

    # augmentation
    aug = ImgAugTransform() if args.augmentation else None

    # load data generator
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size, channels, classes, device=device)  # aug=aug

    # for parallel
    if args.distributed:
        model = nn.DataParallel(model)
    # nn.DataParallel exposes the wrapped network only as `.module`; the
    # freeze/unfreeze logic below has to address the real network, otherwise
    # 'fc' is never found and `.features` raises AttributeError.
    net = model.module if isinstance(model, nn.DataParallel) else model

    # set optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # exponentially decreasing learning rate
    lr_scheduler = ExponentialLR(optimizer, gamma=0.95)

    # progress bar
    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(
        initial=0, leave=False, total=len(train_loader),
        desc=desc.format(0)
    )

    # initialize parameters for iterating over epochs
    burn = 2           # epoch at which the feature layers are unfrozen and LR decay starts
    patience = 3       # number of epochs considered by the early-stopping check
    vl_track = []      # summed validation loss of every epoch run so far
    save_interval = 1  # save a checkpoint every `save_interval` epochs
    n_saved = 5        # number of most recent checkpoints to keep on disk
    tlen = len(train_loader)
    vlen = len(val_loader)
    while epoch < epochs:
        print(f'{epoch+1}/{epochs}')
        print(f"Learning rate: {lr}")
        # First epoch: freeze everything except the 'fc' classification layer.
        if epoch == 0:
            for name, child in net.named_children():
                trainable = name == 'fc'
                print(name + (' is unfrozen' if trainable else ' are frozen'))
                for param in child.parameters():
                    param.requires_grad = trainable
        # After the burn-in: make the feature layers trainable again.
        if epoch == burn:
            print("Turn on all the layers")
            for param in net.features.parameters():
                param.requires_grad = True

        model.train()
        # start of the epoch
        tloss = 0
        acc = np.zeros(1)
        for i, (x, y) in enumerate(train_loader):
            if aug:
                x = torch.from_numpy(aug(x))
            x = x.to(device)
            # as_tensor avoids the copy-construct warning when y is already a tensor
            y = torch.as_tensor(y).long().to(device)
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
            tloss += loss.item()
            acc += accuracy(output, y)
            del loss, output, y, x
            torch.cuda.empty_cache()
            if i > 0 and i % log_interval == 0:
                pbar.desc = desc.format(tloss / (i + 1))
                pbar.update(log_interval)
        # done epoch
        pbar.desc = desc.format(tloss / tlen)
        pbar.update(tlen % log_interval)

        # save checkpoints
        # NOTE(review): under nn.DataParallel the saved keys carry a 'module.'
        # prefix, while loading above happens before wrapping -- confirm the
        # intended checkpoint format before resuming a distributed run.
        if (epoch + 1) % save_interval == 0:
            ch = ''.join(str(c) for c in channels)
            torch.save(model.state_dict(), f'models/Model_pretrained_{ch}_DenseNet121_{epoch+1}.pth')
            if (epoch + 1) // save_interval > n_saved:
                stale = f'models/Model_pretrained_{ch}_DenseNet121_{epoch+1-save_interval*n_saved}.pth'
                try:
                    os.remove(stale)
                except OSError:
                    # Best-effort cleanup: the old checkpoint may not exist.
                    pass

        model.eval()
        # compute loss and accuracy of validation; no gradients are needed here
        vloss = 0
        vacc = np.zeros(1)
        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(device)
                y = torch.as_tensor(y).long().to(device)
                output = model(x)
                loss = criterion(output, y)
                vloss += loss.item()
                vacc += accuracy(output, y)
                del loss, output, y, x
                torch.cuda.empty_cache()
        vl_track.append(vloss)

        print('Epoch {} -> Train Loss: {:.4f}, ACC: {:.2f}%'.format(epoch+1, tloss/tlen, acc[0]/tlen))
        print('Epoch {} -> Validation Loss: {:.4f}, ACC: {:.2f}%'.format(epoch+1, vloss/vlen, vacc[0]/vlen))

        # reset progress bar (report the epoch's mean training loss, not a
        # value divided by the validation loop index)
        pbar.desc = desc.format(tloss / tlen)
        pbar.update(log_interval)
        pbar.n = pbar.last_print_n = 0
        # stop training once the validation loss has risen strictly over each
        # of the last `patience - 1` epoch-to-epoch steps
        if epoch - ep0 >= patience and all(
                vl_track[-1 - k] > vl_track[-2 - k] for k in range(patience - 1)):
            break

        # update learning rate (decay only starts after the burn-in)
        if epoch >= burn:
            lr_scheduler.step()
            lr = float(optimizer.param_groups[0]['lr'])

        epoch += 1

    # checkpoint ignite issue https://github.com/pytorch/ignite/pull/182
    pbar.close()