'accuracy: ' + str(acc)[0:7])

                predict_for_mAP = np.array(predict_for_mAP)
                label_for_mAP = np.array(label_for_mAP)

                MAP = mAP(predict_for_mAP, label_for_mAP, 'Lsm')
                acc = accuracy(predict_for_mAP, label_for_mAP, 'Lsm')

                print("mAP: " + str(MAP) + '  ' + 'accuracy: ' + str(acc))

                if acc > max_test_acc:
                    print('Saving')
                    max_test_acc = acc
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'max_acc': max_test_acc,
                            'epoch': epoch,
                            'step': 0,
                            'opt': optimizer.state_dict()
                        }, opt.model_path + '/' + opt.model_name + '_' +
                        str(epoch) + '_' + str(max_test_acc)[0:6])
                model.train()

                test = False
                predict_for_mAP = []
                label_for_mAP = []

                if opt.test:
                    exit()
Esempio n. 2
0
class Trainer:
    """Train a (optionally class-conditional) DCGAN-style generator/discriminator.

    The discriminator is updated every iteration with a hinge loss
    (``relu(1 + D(fake)) + relu(1 - D(real))``); the generator is updated every
    ``num_discriminator_updates`` iterations with ``-mean(D(G(z)))``. At the end
    of each epoch the averaged losses are sent to ``logger.log``, a sample batch
    is written with ``logger.save_images``, both LR schedulers step and a
    checkpoint is written via :meth:`save`.
    """

    def __init__(self, logger, checkpoint, device_ids, config):
        """Build the dataset, networks and optimizers, optionally resuming.

        Args:
            logger: object exposing ``log_dir``, ``log`` and ``save_images``.
            checkpoint: path to a file written by :meth:`save`, or ``None``.
            device_ids: device ids forwarded to ``DataParallelWithCallback``.
            config: training configuration dict. It is mutated in place to
                record ``n_classes`` in the generator/discriminator params.
        """
        self.config = config
        self.logger = logger
        self.device_ids = device_ids

        self.dataset, n_classes = get_dataset(config['dataset'],
                                              config['dataset_params'])

        if self.config['with_labels']:
            self.config['generator_params']['n_classes'] = n_classes
            self.config['discriminator_params']['n_classes'] = n_classes
            self.config['n_classes'] = n_classes
        else:
            # Unconditional model: networks are built without class embeddings.
            self.config['generator_params']['n_classes'] = None
            self.config['discriminator_params']['n_classes'] = None

        self.restore(checkpoint)

        print("Generator...")
        print(self.generator)

        print("Discriminator...")
        print(self.discriminator)

    def restore(self, checkpoint):
        """Create networks, Adam optimizers and LR schedulers; load a checkpoint if given.

        Each checkpoint entry other than ``'epoch'`` is loaded with
        ``load_state_dict`` into the attribute of the same name (the key set
        written by :meth:`save`).
        """
        self.epoch = 0

        self.generator = DCGenerator(**self.config['generator_params'])
        self.generator = DataParallelWithCallback(self.generator,
                                                  device_ids=self.device_ids)
        self.optimizer_generator = torch.optim.Adam(
            params=self.generator.parameters(),
            lr=self.config['lr_generator'],
            betas=(self.config['b1_generator'], self.config['b2_generator']),
            weight_decay=0,
            eps=1e-8)

        self.discriminator = DCDiscriminator(
            **self.config['discriminator_params'])
        self.discriminator = DataParallelWithCallback(
            self.discriminator, device_ids=self.device_ids)
        self.optimizer_discriminator = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=self.config['lr_discriminator'],
            betas=(self.config['b1_discriminator'],
                   self.config['b2_discriminator']),
            weight_decay=0,
            eps=1e-8)

        if checkpoint is not None:
            data = torch.load(checkpoint)
            # Iterate key/value pairs: iterating the dict itself yields only
            # keys, and unpacking a key string into two names fails.
            for key, value in data.items():
                if key == 'epoch':
                    self.epoch = value
                else:
                    self.__dict__[key].load_state_dict(value)

        # Linear decay of the LR multiplier from 1 towards 0 over num_epochs.
        lr_lambda = lambda epoch: 1 - epoch / self.config['num_epochs']
        self.scheduler_generator = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_generator, lr_lambda, last_epoch=self.epoch - 1)
        self.scheduler_discriminator = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_discriminator, lr_lambda, last_epoch=self.epoch - 1)

    def save(self):
        """Write epoch counter and network/optimizer states to ``<log_dir>/cpk.pth``."""
        state_dict = {
            'epoch': self.epoch,
            'generator': self.generator.state_dict(),
            'optimizer_generator': self.optimizer_generator.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'optimizer_discriminator':
            self.optimizer_discriminator.state_dict()
        }

        torch.save(state_dict, os.path.join(self.logger.log_dir, 'cpk.pth'))

    def train(self):
        """Run the adversarial training loop from ``self.epoch`` to ``num_epochs``."""
        loader = DataLoader(self.dataset,
                            batch_size=self.config['discriminator_bs'],
                            shuffle=False,
                            drop_last=True,
                            num_workers=self.config['num_workers'])
        # Reusable GPU buffers, sized for the larger of the two batch sizes and
        # refilled in place (normal_ / random_) each time they are needed.
        noise = torch.zeros((max(self.config['generator_bs'],
                                 self.config['discriminator_bs']),
                             self.config['generator_params']['dim_z'])).cuda()
        if self.config['with_labels']:
            labels_fake = torch.zeros(
                max(self.config['generator_bs'],
                    self.config['discriminator_bs'])).type(
                        torch.LongTensor).cuda()
        else:
            labels_fake = None

        y_fake = None
        # Keep track of current iteration for update generator
        current_iter = 0
        loss_dict = defaultdict(lambda: 0.0)

        for self.epoch in tqdm(range(self.epoch, self.config['num_epochs'])):
            for data in tqdm(loader):
                self.generator.train()
                current_iter += 1

                images, labels_real = data
                y_real = None if not self.config['with_labels'] else labels_real

                self.optimizer_generator.zero_grad()
                self.optimizer_discriminator.zero_grad()

                # ---- Discriminator step (hinge loss) ----
                z = noise.normal_()[:self.config['discriminator_bs']]
                if self.config['with_labels']:
                    y_fake = labels_fake.random_(
                        self.config['n_classes'])[:self.
                                                  config['discriminator_bs']]

                # Fakes are generated without grad: only D is updated here.
                with torch.no_grad():
                    images_fake = self.generator(z, y_fake)

                logits_real = self.discriminator(images, y_real)
                logits_fake = self.discriminator(images_fake, y_fake)

                loss_fake = torch.relu(1 + logits_fake).mean()
                loss_real = torch.relu(1 - logits_real).mean()

                loss_dict['loss_fake'] += loss_fake.detach().cpu().numpy()
                loss_dict['loss_real'] += loss_real.detach().cpu().numpy()

                (loss_fake + loss_real).backward()
                self.optimizer_discriminator.step()

                # ---- Generator step, every num_discriminator_updates iters ----
                if current_iter % self.config['num_discriminator_updates'] == 0:
                    self.optimizer_discriminator.zero_grad()
                    self.optimizer_generator.zero_grad()

                    z = noise.normal_()[:self.config['generator_bs']]
                    if self.config['with_labels']:
                        y_fake = labels_fake.random_(
                            self.config['n_classes'])[:self.
                                                      config['generator_bs']]

                    images_fake = self.generator(z, y_fake)
                    logits_fake = self.discriminator(images_fake, y_fake)

                    adversarial_loss = -logits_fake.mean()
                    loss_dict['adversarial_loss'] += adversarial_loss.detach(
                    ).cpu().numpy()

                    adversarial_loss.backward()
                    self.optimizer_generator.step()

            # NOTE(review): every entry is divided by the total iteration count,
            # although 'adversarial_loss' only accumulates on generator steps.
            save_dict = {
                key: value / current_iter
                for key, value in loss_dict.items()
            }
            save_dict['lr'] = self.optimizer_generator.param_groups[0]['lr']

            loss_dict = defaultdict(lambda: 0.0)
            current_iter = 0

            with torch.no_grad():
                noise = noise.normal_()
                if self.config['with_labels']:
                    labels_fake = labels_fake.random_(self.config['n_classes'])
                images = self.generator(noise, labels_fake)
                self.logger.save_images(self.epoch, images)

            # if self.epoch % self.config['eval_frequency'] == 0 or self.epoch == self.config['num_epochs'] - 1:
            #     self.generator.eval()
            #
            #     if self.config['samples_evaluation'] != 0:
            #         generated = []
            #         with torch.no_grad():
            #             for i in range(self.config['samples_evaluation'] // noise.shape[0] + 1):
            #                 noise = noise.normal_()
            #                 if self.config['with_labels']:
            #                     labels_fake = labels_fake.random_(self.config['n_classes'])
            #
            #                 generated.append((127.5 * self.generator(noise, labels_fake) + 127.5).cpu().numpy())
            #
            #             generated = np.concatenate(generated)[:self.config['samples_evaluation']]
            #             self.logger.save_evaluation_images(self.epoch, generated)

            self.logger.log(self.epoch, save_dict)

            self.scheduler_generator.step()
            self.scheduler_discriminator.step()
            self.save()
Esempio n. 3
0
                        MAP = mAP(np.array(predict_for_mAP), np.array(label_for_mAP), 'Lsm')
                        acc = accuracy(np.array(predict_for_mAP), np.array(label_for_mAP), 'Lsm')
                        print(" Loss: " + str(TEST_LOSS.avg)[0:5] + '  ' + 'accuracy: ' + str(acc)[0:7])

                predict_for_mAP = np.array(predict_for_mAP)
                label_for_mAP = np.array(label_for_mAP)

                MAP = mAP(predict_for_mAP, label_for_mAP, 'Lsm')
                acc = accuracy(predict_for_mAP, label_for_mAP, 'Lsm')

                print("mAP: " + str(MAP) + '  ' + 'accuracy: ' + str(acc))

                if acc > max_test_acc:
                    print('Saving')
                    max_test_acc = acc
                    torch.save({'model': model.state_dict(), 'max_acc': max_test_acc, 'epoch': epoch, 'step': 0,
                                'opt': optimizer.state_dict()},
                               opt.model_path + '/' + opt.model_name + '_' + str(epoch) + '_' + str(max_test_acc)[0:6])
                model.train()

                test = False
                predict_for_mAP = []
                label_for_mAP = []

                if opt.test:
                    exit()

    if epoch % opt.saveInter == 0:
        print('Saving')
        torch.save({'model': model.state_dict(), 'max_acc': max_test_acc, 'epoch': epoch, 'step': 0, 'opt': optimizer.state_dict()}, opt.model_path + '/' + opt.model_name + '_' + str(epoch))
Esempio n. 4
0
class Model:
    """Training / feature-extraction wrapper around a ``SetNet`` encoder.

    :meth:`fit` optimizes the encoder with a triplet loss over batches drawn by
    ``TripletSampler``. :meth:`transform` runs the encoder over a dataset and
    returns flattened features together with view / sequence-type / label
    metadata. Checkpoints live under ``checkpoint/<model_name>/``.
    """

    def __init__(self,
                 hidden_dim,
                 lr,
                 hard_or_full_trip,
                 margin,
                 num_workers,
                 batch_size,
                 restore_iter,
                 total_iter,
                 save_name,
                 train_pid_num,
                 frame_num,
                 model_name,
                 train_source,
                 test_source,
                 img_size=64):
        """Build the encoder, triplet loss and Adam optimizer on the GPU.

        Args:
            hidden_dim: feature size passed to ``SetNet``.
            lr: Adam learning rate (re-applied at the start of ``fit``).
            hard_or_full_trip: ``'hard'`` or ``'full'``; selects which triplet
                loss term ``fit`` optimizes.
            margin: triplet margin passed to ``TripletLoss``.
            num_workers: DataLoader worker count.
            batch_size: pair ``(P, M)``; ``P * M`` is passed to ``TripletLoss``
                (presumably P identities x M samples each - TODO confirm).
            restore_iter: iteration to resume from; 0 starts fresh.
            total_iter: iteration at which ``fit`` stops.
            save_name: file-name prefix for checkpoints.
            train_pid_num: stored; not used inside this class.
            frame_num: frames sampled per sequence in ``'random'`` mode.
            model_name: checkpoint sub-directory name.
            train_source: dataset used by ``fit`` (and ``transform('train')``).
            test_source: dataset used by ``transform('test')``.
            img_size: stored; not used inside this class.
        """
        self.save_name = save_name
        self.train_pid_num = train_pid_num
        self.train_source = train_source
        self.test_source = test_source

        self.hidden_dim = hidden_dim
        self.lr = lr
        self.hard_or_full_trip = hard_or_full_trip
        self.margin = margin
        self.frame_num = frame_num
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.model_name = model_name
        self.P, self.M = batch_size

        self.restore_iter = restore_iter
        self.total_iter = total_iter

        self.img_size = img_size

        self.encoder = SetNet(self.hidden_dim).float()
        self.encoder = DataParallelWithCallback(self.encoder)
        self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
        self.triplet_loss = DataParallelWithCallback(self.triplet_loss)
        self.encoder.cuda()
        self.triplet_loss.cuda()

        self.optimizer = optim.Adam([
            {'params': self.encoder.parameters()},
        ], lr=self.lr)

        # Running metric buffers, printed and cleared every 100 iterations.
        self.hard_loss_metric = []
        self.full_loss_metric = []
        self.full_loss_num = []
        self.dist_list = []
        self.mean_dist = 0.01

        # 'random' (training: fixed number of sampled frames) or
        # 'all' (inference: every frame, packed per GPU). See collate_fn.
        self.sample_type = 'all'

    def collate_fn(self, batch):
        """Collate dataset items into ``[seqs, view, seq_type, label, batch_frames]``.

        Each item is ``(seq, frame_set, view, seq_type, label)`` where ``seq`` is
        a list of per-feature tables indexed by frame id (pandas-like, judging
        by ``.loc`` / ``.values``).

        * ``sample_type == 'random'``: draw ``frame_num`` frame ids with
          replacement per item and stack into one array per feature;
          ``batch_frames`` stays ``None``.
        * otherwise: keep all frames, split the batch into one contiguous chunk
          per visible GPU, concatenate each chunk along axis 0, zero-pad chunks
          to a common length and return ``len(frame_set)`` per item (zero-filled
          for missing slots in the last chunk) as ``batch_frames``.
        """
        batch_size = len(batch)
        feature_num = len(batch[0][0])
        seqs = [batch[i][0] for i in range(batch_size)]
        frame_sets = [batch[i][1] for i in range(batch_size)]
        view = [batch[i][2] for i in range(batch_size)]
        seq_type = [batch[i][3] for i in range(batch_size)]
        label = [batch[i][4] for i in range(batch_size)]
        batch = [seqs, view, seq_type, label, None]

        def select_frame(index):
            # Return one array per feature for item ``index``.
            sample = seqs[index]
            frame_set = frame_sets[index]
            if self.sample_type == 'random':
                frame_id_list = random.choices(frame_set, k=self.frame_num)
                _ = [feature.loc[frame_id_list].values for feature in sample]
            else:
                _ = [feature.values for feature in sample]
            return _

        seqs = list(map(select_frame, range(len(seqs))))

        if self.sample_type == 'random':
            seqs = [np.asarray([seqs[i][j] for i in range(batch_size)]) for j in range(feature_num)]
        else:
            gpu_num = min(torch.cuda.device_count(), batch_size)
            batch_per_gpu = math.ceil(batch_size / gpu_num)
            # Frame count of every item, grouped by the GPU chunk it lands in.
            batch_frames = [[
                                len(frame_sets[i])
                                for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                                if i < batch_size
                                ] for _ in range(gpu_num)]
            if len(batch_frames[-1]) != batch_per_gpu:
                for _ in range(batch_per_gpu - len(batch_frames[-1])):
                    batch_frames[-1].append(0)
            max_sum_frame = np.max([np.sum(batch_frames[_]) for _ in range(gpu_num)])
            # Concatenate each chunk's sequences along the frame axis ...
            seqs = [[
                        np.concatenate([
                                           seqs[i][j]
                                           for i in range(batch_per_gpu * _, batch_per_gpu * (_ + 1))
                                           if i < batch_size
                                           ], 0) for _ in range(gpu_num)]
                    for j in range(feature_num)]
            # ... then zero-pad every chunk to the longest chunk's frame count.
            seqs = [np.asarray([
                                   np.pad(seqs[j][_],
                                          ((0, max_sum_frame - seqs[j][_].shape[0]), (0, 0), (0, 0)),
                                          'constant',
                                          constant_values=0)
                                   for _ in range(gpu_num)])
                    for j in range(feature_num)]
            batch[4] = np.asarray(batch_frames)

        batch[0] = seqs
        return batch

    def fit(self):
        """Train the encoder until ``restore_iter`` reaches ``total_iter``.

        Resumes from checkpoint ``restore_iter`` when it is non-zero. Every 100
        iterations it saves a checkpoint and prints running metrics; every 1000
        iterations it prints the elapsed wall time.

        Raises:
            ValueError: if ``hard_or_full_trip`` is neither ``'hard'`` nor ``'full'``.
        """
        # Validate first: with any other value ``loss`` would never be bound
        # inside the loop and training would crash with UnboundLocalError.
        if self.hard_or_full_trip not in ('hard', 'full'):
            raise ValueError(
                "hard_or_full_trip must be 'hard' or 'full', got %r"
                % (self.hard_or_full_trip,))

        if self.restore_iter != 0:
            self.load(self.restore_iter)

        self.encoder.train()
        self.sample_type = 'random'
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        triplet_sampler = TripletSampler(self.train_source, self.batch_size)
        train_loader = tordata.DataLoader(
            dataset=self.train_source,
            batch_sampler=triplet_sampler,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        # Sorted label list maps raw labels to contiguous class indices.
        train_label_set = list(self.train_source.label_set)
        train_label_set.sort()

        _time1 = datetime.now()
        for seq, view, seq_type, label, batch_frame in train_loader:
            self.restore_iter += 1
            self.optimizer.zero_grad()

            for i in range(len(seq)):
                seq[i] = self.np2var(seq[i]).float()
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()

            feature, label_prob = self.encoder(*seq, batch_frame)

            target_label = [train_label_set.index(l) for l in label]
            target_label = self.np2var(np.array(target_label)).long()

            # Swap the first two axes of the encoder output and repeat the
            # labels once per resulting leading slice for the triplet loss.
            triplet_feature = feature.permute(1, 0, 2).contiguous()
            triplet_label = target_label.unsqueeze(0).repeat(triplet_feature.size(0), 1)
            (full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
             ) = self.triplet_loss(triplet_feature, triplet_label)
            if self.hard_or_full_trip == 'hard':
                loss = hard_loss_metric.mean()
            else:  # 'full' (validated above)
                loss = full_loss_metric.mean()

            self.hard_loss_metric.append(hard_loss_metric.mean().data.cpu().numpy())
            self.full_loss_metric.append(full_loss_metric.mean().data.cpu().numpy())
            self.full_loss_num.append(full_loss_num.mean().data.cpu().numpy())
            self.dist_list.append(mean_dist.mean().data.cpu().numpy())

            # Skip the update when the loss is (numerically) zero.
            if loss > 1e-9:
                loss.backward()
                self.optimizer.step()

            if self.restore_iter % 1000 == 0:
                print(datetime.now() - _time1)
                _time1 = datetime.now()

            if self.restore_iter % 100 == 0:
                self.save()
                print('iter {}:'.format(self.restore_iter), end='')
                print(', hard_loss_metric={0:.8f}'.format(np.mean(self.hard_loss_metric)), end='')
                print(', full_loss_metric={0:.8f}'.format(np.mean(self.full_loss_metric)), end='')
                print(', full_loss_num={0:.8f}'.format(np.mean(self.full_loss_num)), end='')
                self.mean_dist = np.mean(self.dist_list)
                print(', mean_dist={0:.8f}'.format(self.mean_dist), end='')
                print(', lr=%f' % self.optimizer.param_groups[0]['lr'], end='')
                print(', hard or full=%r' % self.hard_or_full_trip)
                sys.stdout.flush()
                self.hard_loss_metric = []
                self.full_loss_metric = []
                self.full_loss_num = []
                self.dist_list = []

            # ``>=`` rather than ``==``: a run resumed at or beyond total_iter
            # stops instead of overshooting the target iteration.
            if self.restore_iter >= self.total_iter:
                break

    def ts2var(self, x):
        """Move tensor ``x`` to the GPU (legacy ``Variable`` wrapper kept)."""
        return autograd.Variable(x).cuda()

    def np2var(self, x):
        """Convert NumPy array ``x`` to a GPU tensor."""
        return self.ts2var(torch.from_numpy(x))

    def transform(self, flag, batch_size=1):
        """Extract encoder features for the test (``flag == 'test'``) or train source.

        Returns:
            ``(features, views, seq_types, labels)``: ``features`` is a NumPy
            array with one flattened row per leading index of the encoder output
            (presumably one per sample - TODO confirm); the others are lists.
        """
        self.encoder.eval()
        source = self.test_source if flag == 'test' else self.train_source
        self.sample_type = 'all'
        data_loader = tordata.DataLoader(
            dataset=source,
            batch_size=batch_size,
            sampler=tordata.sampler.SequentialSampler(source),
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        feature_list = list()
        view_list = list()
        seq_type_list = list()
        label_list = list()

        # Inference only: skip autograd graph construction to save memory.
        with torch.no_grad():
            for seq, view, seq_type, label, batch_frame in data_loader:
                for j in range(len(seq)):
                    seq[j] = self.np2var(seq[j]).float()
                if batch_frame is not None:
                    batch_frame = self.np2var(batch_frame).int()

                feature, _ = self.encoder(*seq, batch_frame)
                n, num_bin, _ = feature.size()
                feature_list.append(feature.view(n, -1).data.cpu().numpy())
                view_list += view
                seq_type_list += seq_type
                label_list += label

        return np.concatenate(feature_list, 0), view_list, seq_type_list, label_list

    def save(self):
        """Save encoder and optimizer states for the current ``restore_iter``."""
        os.makedirs(osp.join('checkpoint', self.model_name), exist_ok=True)
        torch.save(self.encoder.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-encoder.ptm'.format(
                                self.save_name, self.restore_iter)))
        torch.save(self.optimizer.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-optimizer.ptm'.format(
                                self.save_name, self.restore_iter)))

    # restore_iter: iteration index of the checkpoint to load
    def load(self, restore_iter):
        """Load encoder and optimizer states written by ``save`` at ``restore_iter``."""
        self.encoder.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter))))
        self.optimizer.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter))))
Esempio n. 5
0
class Trainer:
    def __init__(self, logger, checkpoint, device_ids, config):
        """Set up networks/optimizers (via ``restore``) and the training dataset.

        The A-side (B->A) networks are only built when
        ``config['cycle_loss_weight']`` is non-zero.
        """
        self.BtoA = config['cycle_loss_weight'] != 0
        self.config = config
        self.logger = logger
        self.device_ids = device_ids

        self.restore(checkpoint)

        for heading, network in (("Generator...", self.generatorB),
                                 ("Discriminator...", self.discriminatorB)):
            print(heading)
            print(network)

        pipeline = T.Compose([
            T.Resize(config['load_size']),
            T.RandomCrop(config['crop_size']),
            T.ToTensor(),
            # Per-channel (x - 0.5) / 0.5.
            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])

        self.dataset = ABDataset(config['root_dir'],
                                 partition='train',
                                 transform=pipeline)

    def restore(self, checkpoint):
        """Create networks, Adam optimizers and LR schedulers; optionally load a checkpoint.

        Args:
            checkpoint: path to a file produced by ``save`` or ``None``. Every
                entry other than ``'epoch'`` is loaded with ``load_state_dict``
                into the attribute of the same name.
        """
        self.epoch = 0
        cfg = self.config

        def build(net_cls, params_key, lr_key):
            # GPU network wrapped for data parallelism, plus its Adam optimizer.
            net = DataParallelWithCallback(net_cls(**cfg[params_key]).cuda(),
                                           device_ids=self.device_ids)
            optimizer = torch.optim.Adam(net.parameters(),
                                         lr=cfg[lr_key],
                                         betas=(0.5, 0.999))
            return net, optimizer

        # Construction order matters for parameter initialisation (RNG use):
        # B generator, B discriminator, then the optional A pair.
        self.generatorB, self.optimizer_generatorB = build(
            Generator, 'generator_params', 'lr_generator')
        self.discriminatorB, self.optimizer_discriminatorB = build(
            Discriminator, 'discriminator_params', 'lr_discriminator')
        if self.BtoA:
            self.generatorA, self.optimizer_generatorA = build(
                Generator, 'generator_params', 'lr_generator')
            self.discriminatorA, self.optimizer_discriminatorA = build(
                Discriminator, 'discriminator_params', 'lr_discriminator')

        if checkpoint is not None:
            for name, saved in torch.load(checkpoint).items():
                if name != 'epoch':
                    self.__dict__[name].load_state_dict(saved)
                else:
                    self.epoch = saved

        def decay(epoch):
            # Constant multiplier for the first half, then linear decay to 0.
            return min(1, 2 - 2 * epoch / self.config['num_epochs'])

        def schedule(optimizer):
            return torch.optim.lr_scheduler.LambdaLR(
                optimizer, decay, last_epoch=self.epoch - 1)

        self.scheduler_generatorB = schedule(self.optimizer_generatorB)
        self.scheduler_discriminatorB = schedule(self.optimizer_discriminatorB)
        if self.BtoA:
            self.scheduler_generatorA = schedule(self.optimizer_generatorA)
            self.scheduler_discriminatorA = schedule(
                self.optimizer_discriminatorA)

    def save(self):
        """Write the epoch counter and all network/optimizer states to ``<log_dir>/cpk.pth``.

        A-side entries are included only when ``self.BtoA`` is set; key names
        match the attribute names so ``restore`` can load them back.
        """
        attribute_names = ['generatorB', 'optimizer_generatorB',
                           'discriminatorB', 'optimizer_discriminatorB']
        if self.BtoA:
            attribute_names += ['generatorA', 'optimizer_generatorA',
                                'discriminatorA', 'optimizer_discriminatorA']

        checkpoint = {'epoch': self.epoch}
        checkpoint.update(
            (name, getattr(self, name).state_dict()) for name in attribute_names)

        target = os.path.join(self.logger.log_dir, 'cpk.pth')
        torch.save(checkpoint, target)

    def train(self):
        np.random.seed(0)
        loader = DataLoader(self.dataset,
                            batch_size=self.config['bs'],
                            shuffle=False,
                            drop_last=True,
                            num_workers=4)
        images_fixed = None

        for self.epoch in tqdm(range(self.epoch, self.config['num_epochs'])):
            loss_dict = defaultdict(lambda: 0.0)
            iteration_count = 1
            for inp in tqdm(loader):
                images_A = inp['A'].cuda()
                images_B = inp['B'].cuda()

                if images_fixed is None:
                    images_fixed = {'A': images_A, 'B': images_B}
                    transform_fixed = Transform(
                        images_A.shape[0], **self.config['transform_params'])

                if self.config['identity_loss_weight'] != 0:
                    images_trg = self.generatorB(images_B, source=False)
                    identity_loss = l1(images_trg, images_B)
                    identity_loss = self.config[
                        'identity_loss_weight'] * identity_loss
                    identity_loss.backward()

                    loss_dict['identity_loss_B'] += identity_loss.detach().cpu(
                    ).numpy()

                if self.config['identity_loss_weight'] != 0 and self.BtoA:
                    images_trg = self.generatorA(images_A, source=False)
                    identity_loss = l1(images_trg, images_A)
                    identity_loss = self.config[
                        'identity_loss_weight'] * identity_loss
                    identity_loss.backward()

                    loss_dict['identity_loss_A'] += identity_loss.detach().cpu(
                    ).numpy()

                generator_loss = 0
                images_generatedB = self.generatorB(images_A, source=True)
                logits = self.discriminatorB(images_generatedB)
                adversarial_loss = gan_loss_generator(
                    logits, self.config['gan_loss_type'])
                adversarial_loss = self.config[
                    'adversarial_loss_weight'] * adversarial_loss
                generator_loss += adversarial_loss
                loss_dict['adversarial_loss_B'] += adversarial_loss.detach(
                ).cpu().numpy()

                if self.BtoA:
                    images_generatedA = self.generatorA(images_B, source=True)
                    logits = self.discriminatorA(images_generatedA)
                    adversarial_loss = gan_loss_generator(
                        logits, self.config['gan_loss_type'])
                    adversarial_loss = self.config[
                        'adversarial_loss_weight'] * adversarial_loss
                    generator_loss += adversarial_loss
                    loss_dict['adversarial_loss_A'] += adversarial_loss.detach(
                    ).cpu().numpy()

                if self.config['equivariance_loss_weight_generator'] != 0:
                    transform = Transform(images_generatedB.shape[0],
                                          **self.config['transform_params'])
                    images_A_transformed = transform.transform_frame(images_A)
                    loss = corr(
                        self.generatorB(images_A_transformed, source=True),
                        transform.transform_frame(images_generatedB))
                    loss = self.config[
                        'equivariance_loss_weight_generator'] * loss
                    generator_loss += loss
                    loss_dict['equivariance_generator_B'] += loss.detach().cpu(
                    ).numpy()

                if self.config[
                        'equivariance_loss_weight_generator'] != 0 and self.BtoA:
                    transform = Transform(images_generatedA.shape[0],
                                          **self.config['transform_params'])
                    images_B_transformed = transform.transform_frame(images_B)
                    loss = corr(
                        self.generatorB(images_B_transformed, source=True),
                        transform.transform_frame(images_generatedA))
                    loss = self.config[
                        'equivariance_loss_weight_generator'] * loss
                    generator_loss += loss
                    loss_dict['equivariance_generator_A'] += loss.detach().cpu(
                    ).numpy()

                if self.BtoA and self.config[
                        'cycle_loss_weight'] != 0 and self.BtoA:
                    images_cycled = self.generatorA(images_generatedB,
                                                    source=True)
                    cycle_loss = torch.abs(images_cycled - images_A).mean()
                    cycle_loss = self.config['cycle_loss_weight'] * cycle_loss
                    generator_loss += cycle_loss
                    loss_dict['cycle_loss_B'] += cycle_loss.detach().cpu(
                    ).numpy()

                    images_cycled = self.generatorB(images_generatedA,
                                                    source=True)
                    cycle_loss = torch.abs(images_cycled - images_B).mean()
                    cycle_loss = self.config['cycle_loss_weight'] * cycle_loss
                    generator_loss += cycle_loss
                    loss_dict['cycle_loss_A'] += cycle_loss.detach().cpu(
                    ).numpy()

                generator_loss.backward()

                self.optimizer_generatorB.step()
                self.optimizer_generatorB.zero_grad()
                self.optimizer_discriminatorB.zero_grad()

                if self.BtoA:
                    self.optimizer_generatorA.step()
                    self.optimizer_generatorA.zero_grad()
                    self.optimizer_discriminatorA.zero_grad()

                logits_fake = self.discriminatorB(images_generatedB.detach())
                logits_real = self.discriminatorB(images_B)
                discriminator_loss = gan_loss_discriminator(
                    logits_real, logits_fake, self.config['gan_loss_type'])
                loss_dict['discriminator_loss_B'] += discriminator_loss.detach(
                ).cpu().numpy()

                if self.config['equivariance_loss_weight_discriminator'] != 0:
                    images_join = torch.cat(
                        [images_generatedB.detach(), images_B])
                    logits_join = torch.cat([logits_fake, logits_real])

                    transform = Transform(images_join.shape[0],
                                          **self.config['transform_params'])
                    images_transformed = transform.transform_frame(images_join)
                    loss = corr(self.discriminatorB(images_transformed),
                                transform.transform_frame(logits_join))

                    loss = self.config[
                        'equivariance_loss_weight_discriminator'] * loss
                    discriminator_loss += loss
                    loss_dict['equivariance_discriminator_B'] += loss.detach(
                    ).cpu().numpy()

                discriminator_loss.backward()

                self.optimizer_discriminatorB.step()
                self.optimizer_discriminatorB.zero_grad()
                self.optimizer_generatorB.zero_grad()

                if self.BtoA:
                    logits_fake = self.discriminatorA(
                        images_generatedA.detach())
                    logits_real = self.discriminatorA(images_A)
                    discriminator_loss = gan_loss_discriminator(
                        logits_real, logits_fake, self.config['gan_loss_type'])
                    loss_dict[
                        'discriminator_loss_B'] += discriminator_loss.detach(
                        ).cpu().numpy()

                    if self.config[
                            'equivariance_loss_weight_discriminator'] != 0:
                        images_join = torch.cat(
                            [images_generatedA.detach(), images_A])
                        logits_join = torch.cat([logits_fake, logits_real])

                        transform = Transform(
                            images_join.shape[0],
                            **self.config['transform_params'])
                        images_transformed = transform.transform_frame(
                            images_join)
                        loss = corr(self.discriminatorA(images_transformed),
                                    transform.transform_frame(logits_join))

                        loss = self.config[
                            'equivariance_loss_weight_discriminator'] * loss
                        discriminator_loss += loss
                        loss_dict[
                            'equivariance_discriminator_B'] += loss.detach(
                            ).cpu().numpy()

                    discriminator_loss.backward()
                    self.optimizer_discriminatorA.step()
                    self.optimizer_discriminatorA.zero_grad()
                    self.optimizer_generatorA.zero_grad()

                iteration_count += 1

            with torch.no_grad():
                if not self.BtoA:
                    self.generatorB.eval()
                    transformed = transform_fixed.transform_frame(
                        images_fixed['A'])
                    self.logger.save_images(
                        self.epoch, images_fixed['A'],
                        self.generatorB(images_fixed['A'], source=True),
                        transformed, self.generatorB(transformed, source=True))
                    self.generatorB.train()
                else:
                    self.generatorA.eval()
                    self.generatorB.eval()

                    images_generatedB = self.generatorB(images_fixed['A'],
                                                        source=True)
                    images_generatedA = self.generatorA(images_fixed['B'],
                                                        source=True)

                    transformed = transform_fixed.transform_frame(
                        images_fixed['A'])
                    self.logger.save_images(
                        self.epoch, images_fixed['A'], images_generatedB,
                        transformed, self.generatorB(transformed, source=True),
                        self.generatorA(images_generatedB, source=True),
                        images_fixed['B'], images_generatedA,
                        self.generatorB(images_generatedA, source=True))

                    self.generatorA.train()
                    self.generatorB.train()

            self.scheduler_generatorB.step()
            self.scheduler_discriminatorB.step()
            if self.BtoA:
                self.scheduler_generatorA.step()
                self.scheduler_discriminatorA.step()

            save_dict = {
                key: value / iteration_count
                for key, value in loss_dict.items()
            }
            save_dict['lr'] = self.optimizer_generatorB.param_groups[0]['lr']

            self.logger.log(self.epoch, save_dict)
            self.save()
Esempio n. 6
0
def _build_arg_parser():
    """Return the command-line parser for Kinetics training / testing."""
    parser = argparse.ArgumentParser()
    # List-valued options take one or more space-separated values
    # (e.g. ``--LR 1e-3 1e-2``). The previous ``type=list`` turned a
    # command-line string into a list of single characters.
    parser.add_argument('--LR',
                        type=float,
                        nargs='+',
                        default=[1e-4, 1e-4],
                        help='learning rate')  # start from 1e-4
    parser.add_argument('--EPOCH', type=int, default=30, help='epoch')
    parser.add_argument('--slice_num',
                        type=int,
                        default=6,
                        help='how many slices to cut')
    parser.add_argument('--batch_size',
                        type=int,
                        default=40,
                        help='batch_size')
    parser.add_argument('--frame_num',
                        type=int,
                        default=5,
                        help='how many frames in a slice')
    parser.add_argument('--model_path',
                        type=str,
                        default='/Disk1/poli/models/DeepRNN/Kinetics_res18',
                        help='model_path')
    parser.add_argument('--model_name',
                        type=str,
                        default='checkpoint',
                        help='model name')
    parser.add_argument('--video_path',
                        type=str,
                        default='/home/poli/kinetics_scaled',
                        help='video path')
    parser.add_argument('--class_num', type=int, default=400, help='class num')
    parser.add_argument('--device_id',
                        type=int,
                        nargs='+',
                        default=[0, 1, 2, 3],
                        help='GPU device ids (the first one is the main device)')
    parser.add_argument('--resume', action='store_true', help='whether resume')
    parser.add_argument('--dropout',
                        type=float,
                        nargs='+',
                        default=[0.2, 0.5],
                        help='dropout')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=1e-4,
                        help='weight decay')
    parser.add_argument('--saveInter',
                        type=int,
                        default=1,
                        help='how many epoch to save once')
    parser.add_argument('--TD_rate',
                        type=float,
                        default=0.0,
                        help='propabaility of detachout')
    parser.add_argument('--img_size', type=int, default=224, help='image size')
    parser.add_argument('--syn_bn', action='store_true', help='use syn_bn')
    parser.add_argument('--logName',
                        type=str,
                        default='logs_res18',
                        help='log dir name')
    parser.add_argument('--train', action='store_true', help='train the model')
    parser.add_argument('--test', action='store_true', help='test the model')
    parser.add_argument(
        '--overlap_rate',
        type=float,
        default=0.25,
        help='the overlap rate of the overlap coherence training scheme')
    parser.add_argument('--lambdaa',
                        type=float,
                        default=0.0,
                        help='weight of the overlap coherence loss')
    return parser


def _zero_hidden_state(model, batch_size):
    """Return float32 CUDA zero tensors, one per RNN layer of ``model.module``.

    Layer ``i`` gets shape
    ``(batch_size, channels[i + 1], input_size[i], input_size[i])``.
    """
    net = model.module
    return [
        torch.zeros((batch_size, net.channels[layer + 1],
                     net.input_size[layer], net.input_size[layer]),
                    dtype=torch.float32).cuda()
        for layer in range(net.RNN_layer)
    ]


def _save_checkpoint(model, optimizer, max_acc, epoch, path):
    """Save model/optimizer state and bookkeeping to ``path`` (step is stored as 0)."""
    torch.save(
        {
            'model': model.state_dict(),
            'max_acc': max_acc,
            'epoch': epoch,
            'step': 0,
            'opt': optimizer.state_dict()
        }, path)


def _evaluate(model, loader, loss_func):
    """Run ``model`` over every batch of ``loader`` without gradients.

    The model is put in eval mode for the pass and back in train mode
    afterwards. A running loss/accuracy line is printed every 50 batches.

    Returns:
        ``(mAP, accuracy)`` over all batches, as computed by the ``mAP`` /
        ``accuracy`` helpers in ``'Lsm'`` mode.
    """
    test_loss = AverageMeter()
    predictions = []
    labels = []
    with torch.no_grad():
        model.eval()
        print("TESTING")
        for step_test, (x, _, _, action) in tqdm(enumerate(loader)):
            b_x = x.cuda()
            b_action = action.cuda()

            c = _zero_hidden_state(model, len(b_x))
            out, _, _ = model(b_x.float(), c)  # rnn output
            loss = loss_func(out, b_action[:, -1].long())
            test_loss.update(val=loss.data.cpu().numpy())

            # One prediction row and one label (last time step) per sample.
            predictions.extend(out.cpu().numpy())
            labels.extend(b_action[:, -1].cpu().numpy())

            if step_test % 50 == 0:
                acc = accuracy(np.array(predictions), np.array(labels), 'Lsm')
                print(" Loss: " + str(test_loss.avg)[0:5] + '  ' +
                      'accuracy: ' + str(acc)[0:7])

        predictions = np.array(predictions)
        labels = np.array(labels)
        test_map = mAP(predictions, labels, 'Lsm')
        acc = accuracy(predictions, labels, 'Lsm')
    model.train()
    return test_map, acc


def main():
    """Train and/or evaluate the recurrent action model on Kinetics.

    Parses command-line options and builds the model and an SGD optimizer.
    It restores a checkpoint when ``--resume`` or ``--test`` is given, then
    iterates over the training loader.

    With ``--train``, each loader item is processed slice by slice and
    gradients are accumulated over two steps. The validation split is
    evaluated every 700 steps. With ``--test``, evaluation runs on the first
    loader step and the function then returns.

    Checkpoints are written under ``opt.model_path`` whenever validation
    accuracy improves, and every ``opt.saveInter`` epochs.
    """
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    opt = _build_arg_parser().parse_args()
    print(opt)

    torch.cuda.set_device(opt.device_id[0])
    main_device = 'cuda:' + str(opt.device_id[0])

    # ######################## Module #################################
    print('Building model')
    model = actionModel(opt.class_num,
                        batch_norm=True,
                        dropout=opt.dropout,
                        TD_rate=opt.TD_rate,
                        image_size=opt.img_size,
                        syn_bn=opt.syn_bn,
                        test_scheme=3)
    print(model)
    if opt.syn_bn:
        model = DataParallelWithCallback(model,
                                         device_ids=opt.device_id).cuda()
    else:
        model = torch.nn.DataParallel(model, device_ids=opt.device_id).cuda()
    print("Channels: " + str(model.module.channels))

    # ########################Optimizer#########################
    # Per-group learning rates, in the same order as the groups below; reused
    # when resuming so command-line rates override the checkpoint's.
    group_lrs = [opt.LR[0], opt.LR[0], opt.LR[1]]
    optimizer = torch.optim.SGD([
        {'params': model.module.RNN.parameters(), 'lr': group_lrs[0]},
        {'params': model.module.ShortCut.parameters(), 'lr': group_lrs[1]},
        {'params': model.module.classifier.parameters(), 'lr': group_lrs[2]},
    ],
                                lr=opt.LR[1],
                                weight_decay=opt.weight_decay,
                                momentum=0.9)

    # ###################### Loss Function ####################################
    # reduction='mean' is the non-deprecated equivalent of reduce=True.
    loss_classification_func = nn.NLLLoss(reduction='mean')
    mse_loss = nn.MSELoss()

    def loss_overlap_coherence_func(pre, cur):
        # The previous clip's overlap is a fixed target: detach it so no
        # gradient flows back into the earlier clip.
        return mse_loss(cur, pre.detach())

    # ###################### Resume ##########################################
    resume_epoch = 0
    resume_step = 0
    max_test_acc = 0

    if opt.resume or opt.test:
        print("loading model")
        # Remap tensors saved on any of cuda:0..7 onto the main device.
        checkpoint = torch.load(
            opt.model_path + '/' + opt.model_name,
            map_location={'cuda:' + str(gpu): main_device for gpu in range(8)})

        model.load_state_dict(checkpoint['model'], strict=True)
        # Optimizer.load_state_dict has no ``strict`` argument. Passing one
        # raised a TypeError that a bare ``except`` hid, so optimizer state
        # (momentum buffers) was never restored.
        try:
            optimizer.load_state_dict(checkpoint['opt'])
        except (KeyError, ValueError) as err:
            print('Could not restore optimizer state: ' + repr(err))
        for param_group, lr in zip(optimizer.param_groups, group_lrs):
            param_group['lr'] = lr
        resume_epoch = checkpoint['epoch']
        if 'step' in checkpoint:
            resume_step = checkpoint['step'] + 1
        if 'max_acc' in checkpoint:
            max_test_acc = checkpoint['max_acc']
        print('Finish Loading')
        del checkpoint
    # ###########################################################################

    # training and testing
    model.train()

    print("START")

    KineticsLoader = torch.utils.data.DataLoader(
        Kinetic_train_dataset.Kinetics(video_path=opt.video_path +
                                       '/train_frames',
                                       frame_num=opt.frame_num,
                                       batch_size=opt.batch_size,
                                       img_size=opt.img_size,
                                       slice_num=opt.slice_num,
                                       overlap_rate=opt.overlap_rate),
        batch_size=1,
        shuffle=True,
        num_workers=8)
    Loader_test = torch.utils.data.DataLoader(Kinetics_test_dataset.Kinetics(
        video_path=opt.video_path + '/val_frames',
        img_size=224,
        space=5,
        split_num=8,
        lenn=60,
        num_class=opt.class_num),
                                              batch_size=64,
                                              shuffle=True,
                                              num_workers=4)
    steps_per_epoch = len(KineticsLoader)
    # Purge from the first step this run will log, using the same formula as
    # the add_scalar calls below (the old code counted resume_step twice).
    tensorboard_writer = SummaryWriter(
        opt.logName,
        purge_step=(resume_epoch * steps_per_epoch + resume_step) *
        opt.slice_num)
    test = opt.test
    for epoch in range(resume_epoch, opt.EPOCH):

        for step, (x, _, overlap_frame_num,
                   action) in enumerate(KineticsLoader):  # gives batch data
            # TensorBoard x-axis: one tick per slice of every loader step.
            log_step = (epoch * steps_per_epoch + step +
                        resume_step) * opt.slice_num

            if opt.train:
                if step + resume_step >= steps_per_epoch:
                    break
                # The loader batch size is 1; drop that leading dimension.
                x = x[0]
                action = action[0]
                overlap_frame_num = overlap_frame_num[0]

                predict_for_mAP = []
                label_for_mAP = []
                c = _zero_hidden_state(model, x.shape[1])
                for slice_idx in range(x.shape[0]):
                    b_x = x[slice_idx].cuda()
                    b_action = action[slice_idx].cuda()

                    out, out_beforeMerge, c = model(b_x.float(),
                                                    c)  # rnn output
                    predict_for_mAP.extend(out.detach().cpu().numpy())
                    label_for_mAP.extend(b_action[:, -1].cpu().numpy())

                    # ###################### overlap coherence loss ######
                    loss_coherence = torch.zeros(1).cuda()

                    # Compare the head of this clip with the tail recorded
                    # from the previous clip (same overlapping frames).
                    if slice_idx != 0:
                        for b in range(out.size(0)):
                            loss_coherence += loss_overlap_coherence_func(
                                old_overlap[b],
                                torch.exp(out_beforeMerge[
                                    b, :overlap_frame_num[slice_idx, b,
                                                          0].int()]))
                        loss_coherence = loss_coherence / out.size(0)

                    # Record this clip's overlapping tail for the next slice.
                    # NOTE(review): an overlap of 0 makes ``-0:`` select the
                    # whole sequence here; confirm the dataset never yields 0.
                    old_overlap = [
                        torch.exp(out_beforeMerge[
                            b, -overlap_frame_num[slice_idx, b, 0].int():])
                        for b in range(out.size(0))
                    ]
                    # ####################################################

                    loss_classification = loss_classification_func(
                        out, b_action[:, -1].long())

                    loss = loss_classification + opt.lambdaa * loss_coherence
                    tensorboard_writer.add_scalar('train/loss', loss.item(),
                                                  log_step + slice_idx)

                    loss.backward(retain_graph=False)

                # The same numpy inputs the evaluation path feeds mAP/accuracy.
                predictions = np.array(predict_for_mAP)
                labels = np.array(label_for_mAP)
                mAPs = mAP(predictions, labels, 'Lsm')
                acc = accuracy(predictions, labels, 'Lsm')
                tensorboard_writer.add_scalar('train/mAP', mAPs,
                                              log_step + slice_idx)
                tensorboard_writer.add_scalar('train/acc', acc,
                                              log_step + slice_idx)

                print("Epoch: " + str(epoch) + " step: " +
                      str(step + resume_step) + " Loss: " +
                      str(loss.detach().cpu().numpy()) + " Loss_coherence: " +
                      str(loss_coherence.detach().cpu().numpy()) + " mAP: " +
                      str(mAPs)[0:7] + " acc: " + str(acc)[0:7])

                # Clamp each gradient element to [-5, 5]; parameters without
                # a gradient are skipped instead of raising.
                torch.nn.utils.clip_grad_value_(model.module.parameters(), 5)

                # Gradients are accumulated over two loader steps.
                if step % 2 == 1:
                    optimizer.step()
                    optimizer.zero_grad()

            # ################################### test ###############################
            if (step + resume_step) % 700 == 699:
                test = True

            if test:
                print('Start Test')
                test_map, test_acc = _evaluate(model, Loader_test,
                                               loss_classification_func)
                print("mAP: " + str(test_map) + '  ' + 'accuracy: ' +
                      str(test_acc))

                if test_acc > max_test_acc:
                    print('Saving')
                    max_test_acc = test_acc
                    _save_checkpoint(
                        model, optimizer, max_test_acc, epoch,
                        opt.model_path + '/' + opt.model_name + '_' +
                        str(epoch) + '_' + str(max_test_acc)[0:6])

                test = False

                if opt.test:
                    tensorboard_writer.close()
                    return

        if epoch % opt.saveInter == 0:
            print('Saving')
            _save_checkpoint(
                model, optimizer, max_test_acc, epoch,
                opt.model_path + '/' + opt.model_name + '_' + str(epoch))

        resume_step = 0

    tensorboard_writer.close()
Esempio n. 7
0
# Build the model with dropout disabled and wrap it for multi-GPU execution.
print('Building model')
model = DataParallelWithCallback(
    actionModel(args.class_num, batch_norm=True, dropout=[0, 0, 0]),
    device_ids=args.device_id,
).cuda()


print("loading model")
# Tensors saved on any of cuda:0..7 are remapped onto the first listed device.
checkpoint = torch.load(
    args.model_path + '/' + args.model_name,
    map_location={
        'cuda:' + str(gpu): 'cuda:' + str(args.device_id[0])
        for gpu in range(8)
    })
pre_train = checkpoint['model']
model_dict = model.state_dict()
# Partial load: take only checkpoint entries whose names exist in the current
# model; checkpoint keys unknown to the model are silently ignored.
model_dict.update({
    name: tensor
    for name, tensor in pre_train.items() if name in model_dict
})
model.load_state_dict(model_dict)
print('Finish Loading')
del checkpoint, pre_train, model_dict
print("Model: " + str(args.model_name))


predict_for_mAP, label_for_mAP = [], []

print("START")

UCF101Loader_test = torch.utils.data.DataLoader(