def train_and_evaluate(model, train_dataloader, val_dataloader, optimizer,
                       loss_fn, metrics, params, model_dir, restore_file=None):
    """Train the model and evaluate every epoch.

    Args:
        model: (torch.nn.Module) the neural network
        train_dataloader: (DataLoader) iterator over the training set
        val_dataloader: (DataLoader) iterator over the validation set
        optimizer: (torch.optim) optimizer for the model parameters
        loss_fn: function taking (outputs, labels) and returning the loss
        metrics: (dict) metric name -> function computing that metric
        params: (Params) hyperparameters
        model_dir: (string) directory containing config, weights and log
        restore_file: (string) - name of file to restore from (without its extension .pth.tar)
    """
    # reload weights from restore_file if specified
    if restore_file is not None:
        # BUG FIX: the original read the module-level `args` instead of the
        # `model_dir` / `restore_file` parameters passed to this function.
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    best_val_acc = 0.0

    # learning rate schedulers for different models; None means constant LR.
    # FIX: the original raised NameError at scheduler.step() for any other
    # model_version; now the scheduler is simply skipped in that case.
    scheduler = None
    if params.model_version == "resnet18":
        scheduler = StepLR(optimizer, step_size=150, gamma=0.1)
    # for cnn models, num_epoch is always < 100, so this decay never fires
    elif params.model_version == "cnn":
        scheduler = StepLR(optimizer, step_size=100, gamma=0.2)

    for epoch in range(params.num_epochs):

        # NOTE: stepping at the start of the epoch preserves the original
        # (pre-PyTorch-1.1) schedule semantics of this code.
        if scheduler is not None:
            scheduler.step()

        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        # one full pass over the training set
        train(model, optimizer, loss_fn, train_dataloader, metrics, params)

        # Evaluate for one epoch on validation set
        val_metrics = evaluate(model, loss_fn, val_dataloader, metrics, params)

        val_acc = val_metrics['accuracy']
        is_best = val_acc >= best_val_acc

        # Save weights (always keeps the latest; utils also keeps the best)
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'optim_dict': optimizer.state_dict()},
                              is_best=is_best,
                              checkpoint=model_dir)

        # If best_eval, best_save_path
        if is_best:
            logging.info("- Found new best accuracy")
            best_val_acc = val_acc

            # Save best val metrics in a json file in the model directory
            best_json_path = os.path.join(model_dir, "metrics_val_best_weights.json")
            utils.save_dict_to_json(val_metrics, best_json_path)

        # Save latest val metrics in a json file in the model directory
        last_json_path = os.path.join(model_dir, "metrics_val_last_weights.json")
        utils.save_dict_to_json(val_metrics, last_json_path)
# ---- Example #2 ----
class Reconstructor(object):
    def __init__(self, args):
        """Build the VAE reconstructor: encoder/decoder, six latent-space
        discriminators, their optimizers and all bookkeeping state.

        Args:
            args: parsed command-line namespace; the attribute assignments
                below document exactly which fields are read.
        """
        # output directory for checkpoints, metrics and sample images
        self.reconstruction_path = args.reconstruction_path
        if not os.path.exists(self.reconstruction_path):
            os.makedirs(self.reconstruction_path)

        self.beta = args.beta
        self.train_batch_size = args.train_batch_size
        self.test_batch_size = args.test_batch_size
        self.epochs = args.epochs
        self.early_stop = args.early_stop
        self.early_stop_observation_period = args.early_stop_observation_period
        # hard-coded off: scheduler_enc/scheduler_dec below are created but
        # only stepped in train() when this flag is True
        self.use_scheduler = False
        self.print_training = args.print_training
        self.class_num = args.class_num
        self.disentangle_with_reparameterization = args.disentangle_with_reparameterization

        # the latent vector z is split in half: the first half carries class
        # information, the second half membership information
        self.z_dim = args.z_dim
        self.disc_input_dim = int(self.z_dim / 2)
        self.class_idx = range(0, self.disc_input_dim)
        self.membership_idx = range(self.disc_input_dim, self.z_dim)

        self.nets = dict()

        # choose conv vs fully-connected VAE depending on the dataset
        if args.dataset in ['MNIST', 'Fashion-MNIST', 'CIFAR-10', 'SVHN']:
            if args.dataset in ['MNIST', 'Fashion-MNIST']:
                self.num_channels = 1
            elif args.dataset in ['CIFAR-10', 'SVHN']:
                self.num_channels = 3

            self.nets['encoder'] = module.VAEConvEncoder(
                self.z_dim, self.num_channels)
            self.nets['decoder'] = module.VAEConvDecoder(
                self.z_dim, self.num_channels)

        elif args.dataset in ['adult', 'location']:
            self.nets['encoder'] = module.VAEFCEncoder(args.encoder_input_dim,
                                                       self.z_dim)
            self.nets['decoder'] = module.FCDecoder(args.encoder_input_dim,
                                                    self.z_dim)

        # six discriminators: class / membership prediction from the full z
        # (fz), the class half (cz) and the membership half (mz); membership
        # discriminators are additionally conditioned on a one-hot class label
        self.discs = {
            'class_fz':
            module.ClassDiscriminator(self.z_dim, args.class_num),
            'class_cz':
            module.ClassDiscriminator(self.disc_input_dim, args.class_num),
            'class_mz':
            module.ClassDiscriminator(self.disc_input_dim, args.class_num),
            'membership_fz':
            module.MembershipDiscriminator(self.z_dim + args.class_num, 1),
            'membership_cz':
            module.MembershipDiscriminator(
                self.disc_input_dim + args.class_num, 1),
            'membership_mz':
            module.MembershipDiscriminator(
                self.disc_input_dim + args.class_num, 1),
        }

        self.recon_loss = self.get_loss_function()
        self.class_loss = nn.CrossEntropyLoss(reduction='sum')
        self.membership_loss = nn.BCEWithLogitsLoss(reduction='sum')

        # optimizer: one Adam instance per network and per discriminator,
        # all stored in a single dict keyed by the net/disc name
        self.optimizer = dict()
        for net_type in self.nets:
            self.optimizer[net_type] = optim.Adam(
                self.nets[net_type].parameters(),
                lr=args.recon_lr,
                betas=(0.5, 0.999))
        self.discriminator_lr = args.disc_lr
        for disc_type in self.discs:
            self.optimizer[disc_type] = optim.Adam(
                self.discs[disc_type].parameters(),
                lr=self.discriminator_lr,
                betas=(0.5, 0.999))

        # encoder-update weights; recon scales the VAE loss, the others scale
        # the disentanglement terms in disentangle_z()
        self.weights = {
            'recon': args.recon_weight,
            'class_cz': args.class_cz_weight,
            'class_mz': args.class_mz_weight,
            'membership_cz': args.membership_cz_weight,
            'membership_mz': args.membership_mz_weight,
        }

        self.scheduler_enc = StepLR(self.optimizer['encoder'],
                                    step_size=50,
                                    gamma=0.1)
        self.scheduler_dec = StepLR(self.optimizer['decoder'],
                                    step_size=50,
                                    gamma=0.1)

        # to device
        self.device = torch.device("cuda:{}".format(args.gpu_id))
        for net_type in self.nets:
            self.nets[net_type] = self.nets[net_type].to(self.device)
        for disc_type in self.discs:
            self.discs[disc_type] = self.discs[disc_type].to(self.device)

        # run the adversarial disentanglement step only if any of its
        # loss weights is non-zero
        self.disentangle = (
            self.weights['class_cz'] + self.weights['class_mz'] +
            self.weights['membership_cz'] + self.weights['membership_mz'] > 0)

        self.start_epoch = 0
        self.best_valid_loss = float("inf")
        # self.train_loss = 0
        self.early_stop_count = 0

        # per-epoch discriminator accuracies, filled in by train_epoch()
        self.acc_dict = {
            'class_fz': 0,
            'class_cz': 0,
            'class_mz': 0,
            'membership_fz': 0,
            'membership_cz': 0,
            'membership_mz': 0,
        }
        self.best_acc_dict = {}

        if 'cuda' in str(self.device):
            cudnn.benchmark = True

        if args.resume:
            print('==> Resuming from checkpoint..')
            try:
                self.load()
            except FileNotFoundError:
                print(
                    'There is no pre-trained model; Train model from scratch')

    #########################
    # -- Base operations -- #
    #########################
    def load(self):
        """Restore every network and discriminator (plus the starting epoch)
        from 'ckpt.pth' in the reconstruction directory.

        Raises:
            FileNotFoundError: if no checkpoint file exists.
        """
        ckpt_file = os.path.join(self.reconstruction_path, 'ckpt.pth')
        state = torch.load(ckpt_file)
        for name, net in self.nets.items():
            net.load_state_dict(state[name])
        for name, disc in self.discs.items():
            disc.load_state_dict(state[name])
        self.start_epoch = state['epoch']
    def train_epoch(self, train_ref_loader, epoch):
        """Run one training epoch over paired (train, reference) batches.

        Trains the reconstructor, all six discriminators and (optionally)
        the disentangling encoder update, accumulating per-discriminator
        accuracies into self.acc_dict.

        Args:
            train_ref_loader: loader yielding (inputs, targets, inputs_ref,
                targets_ref) tuples — a training batch paired with a
                same-sized reference ("out") batch.
            epoch: (int) epoch index, used only for logging.
        """
        for net_type in self.nets:
            self.nets[net_type].train()
        for disc_type in self.discs:
            self.discs[disc_type].train()

        total = 0

        # running sums over the epoch; divided by batch count only for the
        # printout at the bottom
        losses = {
            'MSE': 0.,
            'KLD': 0.,
            'class_fz': 0.,
            'class_cz': 0.,
            'class_mz': 0.,
            'membership_fz': 0.,
            'membership_cz': 0.,
            'membership_mz': 0.,
        }

        # running correct-prediction counts (MSE/KLD keys are unused here)
        corrects = {
            'MSE': 0.,
            'KLD': 0.,
            'class_fz': 0.,
            'class_cz': 0.,
            'class_mz': 0.,
            'membership_fz': 0.,
            'membership_cz': 0.,
            'membership_mz': 0.,
        }

        for batch_idx, (inputs, targets, inputs_ref,
                        targets_ref) in enumerate(train_ref_loader):
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            inputs_ref, targets_ref = inputs_ref.to(
                self.device), targets_ref.to(self.device)

            total += targets.size(0)

            # ---- Reconstruction (Encoder & Decoder) ---- #
            # NOTE(review): the returned weighted recon_loss is discarded;
            # only its MSE and KLD components are logged.
            recon_loss, MSE, KLD = self.train_reconstructor(inputs)
            losses['MSE'] += MSE
            losses['KLD'] += KLD

            # ---- Class discriminators ---- #
            correct_class_fz, loss_class_fz = self.train_disc_class_fz(
                inputs, targets)
            correct_class_cz, loss_class_cz = self.train_disc_class_cz(
                inputs, targets)
            correct_class_mz, loss_class_mz = self.train_disc_class_mz(
                inputs, targets)

            corrects['class_fz'] += correct_class_fz
            corrects['class_cz'] += correct_class_cz
            corrects['class_mz'] += correct_class_mz
            losses['class_fz'] += loss_class_fz
            losses['class_cz'] += loss_class_cz
            losses['class_mz'] += loss_class_mz

            # ---- Membership discriminators ---- #
            correct_membership_fz, loss_membership_fz = self.train_disc_membership_fz(
                inputs, targets, inputs_ref, targets_ref)
            correct_membership_cz, loss_membership_cz = self.train_disc_membership_cz(
                inputs, targets, inputs_ref, targets_ref)
            correct_membership_mz, loss_membership_mz = self.train_disc_membership_mz(
                inputs, targets, inputs_ref, targets_ref)
            corrects['membership_fz'] += correct_membership_fz
            corrects['membership_cz'] += correct_membership_cz
            corrects['membership_mz'] += correct_membership_mz
            losses['membership_fz'] += loss_membership_fz
            losses['membership_cz'] += loss_membership_cz
            losses['membership_mz'] += loss_membership_mz

            # adversarial encoder update, only when any disentangling weight
            # is non-zero (see __init__)
            if self.disentangle:
                self.disentangle_z(inputs, targets)

            # ---- Swap membership info ---- #
            # NOTE(review): the swapped reconstructions below are computed but
            # never used or returned — apparently dead/experimental code.
            # Also, torch.cat defaults to dim=0 here, which stacks the two
            # batches rather than joining class/membership halves per sample
            # (split_class_membership splits along the feature dim) — confirm
            # the intended dim before reusing this.
            z_tr = self.inference_z(inputs)
            cz_tr, mz_tr = self.split_class_membership(z_tr)

            z_re = self.inference_z(inputs_ref)
            cz_re, mz_re = self.split_class_membership(z_re)

            z_ctr_mre = torch.cat([cz_tr, mz_re])
            z_cre_mtr = torch.cat([cz_re, mz_tr])

            recon_ctr_mre = self.nets['decoder'](z_ctr_mre)
            recon_cre_mtr = self.nets['decoder'](z_cre_mtr)

        # todo : loop
        # class accuracy: one prediction per training sample
        self.acc_dict['class_fz'] = corrects['class_fz'] / total
        self.acc_dict['class_cz'] = corrects['class_cz'] / total
        self.acc_dict['class_mz'] = corrects['class_mz'] / total

        # membership accuracy: each batch scores train ("in") AND reference
        # ("out") samples, hence 2 * total predictions overall
        self.acc_dict['membership_fz'] = corrects['membership_fz'] / (2 *
                                                                      total)
        self.acc_dict['membership_cz'] = corrects['membership_cz'] / (2 *
                                                                      total)
        self.acc_dict['membership_mz'] = corrects['membership_mz'] / (2 *
                                                                      total)

        if self.print_training:
            print(
                '\nEpoch: {:>3}, Acc) Class (fz, cz, mz) : {:.4f}, {:.4f}, {:.4f}, Membership (fz, cz, mz) : {:.4f}, {:.4f}, {:.4f}'
                .format(
                    epoch,
                    self.acc_dict['class_fz'],
                    self.acc_dict['class_cz'],
                    self.acc_dict['class_mz'],
                    self.acc_dict['membership_fz'],
                    self.acc_dict['membership_cz'],
                    self.acc_dict['membership_mz'],
                ))

            # NOTE(review): losses are averaged over batches only inside this
            # print_training branch; they are not used afterwards, so this
            # affects the printout only. batch_idx survives from the loop.
            for loss_type in losses:
                losses[loss_type] = losses[loss_type] / (batch_idx + 1)
            print(
                'Losses) MSE: {:.2f}, KLD: {:.2f}, Class (fz, cz, mz): {:.2f}, {:.2f}, {:.2f}, Membership (fz, cz, mz): {:.2f}, {:.2f}, {:.2f},'
                .format(
                    losses['MSE'],
                    losses['KLD'],
                    losses['class_fz'],
                    losses['class_cz'],
                    losses['class_mz'],
                    losses['membership_fz'],
                    losses['membership_cz'],
                    losses['membership_mz'],
                ))

    def train_reconstructor(self, inputs):
        self.optimizer['encoder'].zero_grad()
        self.optimizer['decoder'].zero_grad()
        mu, logvar = self.nets['encoder'](inputs)
        z = self.reparameterize(mu, logvar)
        recons = self.nets['decoder'](z)
        recon_loss, MSE, KLD = self.recon_loss(recons, inputs, mu, logvar)
        recon_loss = self.weights['recon'] * recon_loss
        recon_loss.backward()
        self.optimizer['encoder'].step()
        self.optimizer['decoder'].step()
        return recon_loss.item(), MSE.item(), KLD.item()

    def train_disc_class_fz(self, inputs, targets):
        self.optimizer['class_fz'].zero_grad()
        z = self.inference_z(inputs)
        pred = self.discs['class_fz'](z)
        class_loss_full = self.class_loss(pred, targets)
        class_loss_full.backward()
        self.optimizer['class_fz'].step()

        _, pred_class_from_full = pred.max(1)
        return pred_class_from_full.eq(
            targets).sum().item(), class_loss_full.item()

    def train_disc_class_cz(self, inputs, targets):
        self.optimizer['class_cz'].zero_grad()
        z = self.inference_z(inputs)
        class_z, _ = self.split_class_membership(z)
        pred = self.discs['class_cz'](class_z)
        class_loss = self.class_loss(pred, targets)
        class_loss.backward()
        self.optimizer['class_cz'].step()

        _, pred_class = pred.max(1)
        return pred_class.eq(targets).sum().item(), class_loss.item()

    def train_disc_class_mz(self, inputs, targets):
        self.optimizer['class_mz'].zero_grad()
        z = self.inference_z(inputs)
        _, membership_z = self.split_class_membership(z)
        pred = self.discs['class_mz'](membership_z)
        class_loss_membership = self.class_loss(pred, targets)
        class_loss_membership.backward()
        self.optimizer['class_mz'].step()

        _, pred_class_from_membership = pred.max(1)
        return pred_class_from_membership.eq(
            targets).sum().item(), class_loss_membership.item()

    def train_disc_membership_fz(self, inputs, targets, inputs_ref,
                                 targets_ref):
        self.optimizer['membership_fz'].zero_grad()

        z = self.inference_z(inputs)
        targets_onehot = torch.zeros(
            (len(targets), self.class_num)).to(self.device)
        targets_onehot = targets_onehot.scatter_(1, targets.reshape((-1, 1)),
                                                 1)
        z = torch.cat((z, targets_onehot), dim=1)
        pred = self.discs['membership_fz'](z)
        in_loss = self.membership_loss(pred, torch.ones_like(pred))

        z_ref = self.inference_z(inputs_ref)
        targets_ref_onehot = torch.zeros(
            (len(targets_ref), self.class_num)).to(self.device)
        targets_ref_onehot = targets_ref_onehot.scatter_(
            1, targets_ref.reshape((-1, 1)), 1)
        z_ref = torch.cat((z_ref, targets_ref_onehot), dim=1)
        pred_ref = self.discs['membership_fz'](z_ref)
        out_loss = self.membership_loss(pred_ref, torch.zeros_like(pred_ref))

        membership_loss = in_loss + out_loss
        membership_loss.backward()
        self.optimizer['membership_fz'].step()

        pred = pred.cpu().detach().numpy().squeeze()
        pred_ref = pred_ref.cpu().detach().numpy().squeeze()
        pred_concat = np.concatenate((pred, pred_ref))
        inout_concat = np.concatenate(
            (np.ones_like(pred), np.zeros_like(pred_ref)))

        return np.sum(
            inout_concat == np.round(pred_concat)), membership_loss.item()

    def train_disc_membership_cz(self, inputs, targets, inputs_ref,
                                 targets_ref):
        self.optimizer['membership_cz'].zero_grad()

        z = self.inference_z(inputs)
        class_z, _ = self.split_class_membership(z)
        targets_onehot = torch.zeros(
            (len(targets), self.class_num)).to(self.device)
        targets_onehot = targets_onehot.scatter_(1, targets.reshape((-1, 1)),
                                                 1)
        class_z = torch.cat((class_z, targets_onehot), dim=1)
        pred = self.discs['membership_cz'](class_z)
        in_loss = self.membership_loss(pred, torch.ones_like(pred))

        z_ref = self.inference_z(inputs_ref)
        class_z_ref, _ = self.split_class_membership(z_ref)
        targets_ref_onehot = torch.zeros(
            (len(targets_ref), self.class_num)).to(self.device)
        targets_ref_onehot = targets_ref_onehot.scatter_(
            1, targets_ref.reshape((-1, 1)), 1)
        class_z_ref = torch.cat((class_z_ref, targets_ref_onehot), dim=1)
        pred_ref = self.discs['membership_cz'](class_z_ref)
        out_loss = self.membership_loss(pred_ref, torch.zeros_like(pred_ref))

        membership_loss = in_loss + out_loss
        membership_loss.backward()
        self.optimizer['membership_cz'].step()

        pred = pred.cpu().detach().numpy().squeeze()
        pred_ref = pred_ref.cpu().detach().numpy().squeeze()
        pred_concat = np.concatenate((pred, pred_ref))
        inout_concat = np.concatenate(
            (np.ones_like(pred), np.zeros_like(pred_ref)))

        return np.sum(
            inout_concat == np.round(pred_concat)), membership_loss.item()

    def train_disc_membership_mz(self, inputs, targets, inputs_ref,
                                 targets_ref):
        self.optimizer['membership_mz'].zero_grad()

        z = self.inference_z(inputs)
        _, membership_z = self.split_class_membership(z)
        targets_onehot = torch.zeros(
            (len(targets), self.class_num)).to(self.device)
        targets_onehot = targets_onehot.scatter_(1, targets.reshape((-1, 1)),
                                                 1)
        membership_z = torch.cat((membership_z, targets_onehot), dim=1)
        pred = self.discs['membership_mz'](membership_z)
        in_loss = self.membership_loss(pred, torch.ones_like(pred))

        z_ref = self.inference_z(inputs_ref)
        _, membership_z_ref = self.split_class_membership(z_ref)
        targets_ref_onehot = torch.zeros(
            (len(targets_ref), self.class_num)).to(self.device)
        targets_ref_onehot = targets_ref_onehot.scatter_(
            1, targets_ref.reshape((-1, 1)), 1)
        membership_z_ref = torch.cat((membership_z_ref, targets_ref_onehot),
                                     dim=1)
        pred_ref = self.discs['membership_mz'](membership_z_ref)
        out_loss = self.membership_loss(pred_ref, torch.zeros_like(pred_ref))

        membership_loss = in_loss + out_loss
        membership_loss.backward()
        self.optimizer['membership_mz'].step()

        pred = pred.cpu().detach().numpy().squeeze()
        pred_ref = pred_ref.cpu().detach().numpy().squeeze()
        pred_concat = np.concatenate((pred, pred_ref))
        inout_concat = np.concatenate(
            (np.ones_like(pred), np.zeros_like(pred_ref)))

        return np.sum(
            inout_concat == np.round(pred_concat)), membership_loss.item()

    def disentangle_z(self, inputs, targets):
        self.optimizer['encoder'].zero_grad()
        loss = 0

        z = self.inference_z(inputs)
        cz, mz = self.split_class_membership(z)
        targets_onehot = torch.zeros(
            (len(targets), self.class_num)).to(self.device)
        targets_onehot = targets_onehot.scatter_(1, targets.reshape((-1, 1)),
                                                 1)

        if self.weights['class_cz'] != 0:
            pred = self.discs['class_cz'](cz)
            loss += self.weights['class_cz'] * self.class_loss(pred, targets)

        if self.weights['class_mz'] != 0:
            pred = self.discs['class_mz'](mz)
            loss += -self.weights['class_mz'] * self.class_loss(pred, targets)

        if self.weights['membership_cz'] != 0:
            pred = self.discs['membership_cz'](torch.cat((cz, targets_onehot),
                                                         dim=1))
            # pred = self.discs['membership_cz'](cz)
            loss += -self.weights['membership_cz'] * self.membership_loss(
                pred, torch.ones_like(pred))

        if self.weights['membership_mz'] != 0:
            pred = self.discs['membership_mz'](torch.cat((mz, targets_onehot),
                                                         dim=1))
            # pred = self.discs['membership_mz'](mz)
            loss += self.weights['membership_mz'] * self.membership_loss(
                pred, torch.ones_like(pred))

        loss.backward()
        self.optimizer['encoder'].step()

    def inference(self, loader, epoch, type='valid'):
        for net_type in self.nets:
            self.nets[net_type].eval()
        for disc_type in self.discs:
            self.discs[disc_type].eval()

        loss = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(loader):
                inputs, targets = inputs.to(self.device), targets.to(
                    self.device)

                mu, logvar = self.nets['encoder'](inputs)
                z = self.reparameterize(mu, logvar)

                recons = self.nets['decoder'](z)
                recon_loss, MSE, KLD = self.recon_loss(recons, inputs, mu,
                                                       logvar)
                loss += recon_loss.item()

        if type == 'valid':
            if loss < self.best_valid_loss:
                state = {
                    'best_valid_loss': loss,
                    'epoch': epoch,
                }

                for net_type in self.nets:
                    state[net_type] = self.nets[net_type].state_dict()
                for disc_type in self.discs:
                    state[disc_type] = self.discs[disc_type].state_dict()

                torch.save(state,
                           os.path.join(self.reconstruction_path, 'ckpt.pth'))
                self.best_valid_loss = loss
                self.early_stop_count = 0
                self.best_acc_dict = self.acc_dict

                np.save(os.path.join(self.reconstruction_path, 'acc.npy'),
                        self.best_acc_dict)
                vutils.save_image(recons,
                                  os.path.join(self.reconstruction_path,
                                               '{}.png'.format(epoch)),
                                  nrow=10)

            else:
                self.early_stop_count += 1
                if self.print_training:
                    print('Early stop count: {}'.format(self.early_stop_count))

            if self.early_stop_count == self.early_stop_observation_period:
                print(self.best_acc_dict)
                if self.print_training:
                    print(
                        'Early stop count == {}; Terminate training\n'.format(
                            self.early_stop_observation_period))
                self.train_flag = False

    def train(self, train_set, valid_set=None, ref_set=None):
        """Run the full training loop.

        Every epoch the reference set is reshuffled and re-paired with the
        training set; validation (and early stopping) runs only when
        self.early_stop is enabled.
        """
        print('==> Start training {}'.format(self.reconstruction_path))
        self.train_flag = True

        if self.early_stop:
            valid_loader = DataLoader(valid_set,
                                      batch_size=self.train_batch_size,
                                      shuffle=True,
                                      num_workers=2)

        for epoch in range(self.start_epoch, self.start_epoch + self.epochs):
            # reshuffle the reference set so in/out pairings differ per epoch
            shuffled = np.random.permutation(ref_set.__len__())
            ref_set = Subset(ref_set, shuffled)
            paired_loader = DataLoader(data.DoubleDataset(train_set, ref_set),
                                       batch_size=self.train_batch_size,
                                       shuffle=True,
                                       num_workers=2)

            # guard clause: inference() clears the flag on early stop
            if not self.train_flag:
                break

            self.train_epoch(paired_loader, epoch)
            if self.use_scheduler:
                self.scheduler_enc.step()
                self.scheduler_dec.step()
            if self.early_stop:
                self.inference(valid_loader, epoch, type='valid')

    def reconstruct(self, dataset_dict, reconstruction_type_list):
        try:
            self.load()
        except FileNotFoundError:
            print(
                'There is no pre-trained model; First, train a reconstructor.')
            sys.exit(1)
        self.nets['encoder'].eval()
        self.nets['decoder'].eval()

        mse_list = []
        recon_dict = dict()

        for recon_idx, reconstruction_type in enumerate(
                reconstruction_type_list):
            recon_datasets_dict = {}
            for dataset_type, dataset in dataset_dict.items():
                loader = DataLoader(dataset,
                                    batch_size=self.test_batch_size,
                                    shuffle=False,
                                    num_workers=2)
                raws = []
                recons = []
                labels = []
                with torch.no_grad():
                    for batch_idx, (inputs, targets) in enumerate(loader):
                        inputs = inputs.to(self.device)
                        mu, logvar = self.nets['encoder'](inputs)

                        z = torch.zeros_like(mu).to(self.device)

                        mu_class, mu_membership = self.split_class_membership(
                            mu)
                        logvar_class, logvar_membership = self.split_class_membership(
                            logvar)

                        if reconstruction_type == 'cb_mb':
                            z[:, self.class_idx] = mu_class
                            z[:, self.membership_idx] = mu_membership
                        elif reconstruction_type == 'cb_mz':
                            z[:, self.class_idx] = mu_class
                            z[:, self.membership_idx] = torch.zeros_like(
                                mu_membership).to(self.device)
                        elif reconstruction_type == 'cz_mb':
                            z[:,
                              self.class_idx] = torch.zeros_like(mu_class).to(
                                  self.device)
                            z[:, self.membership_idx] = mu_membership
                        elif reconstruction_type == 'cs1.2_ms0.8':  # scaling
                            z[:, self.class_idx] = mu_class * 1.2
                            z[:, self.membership_idx] = mu_membership * 0.8
                        elif reconstruction_type == 'cb_ms0.8':  # scaling
                            z[:, self.class_idx] = mu_class
                            z[:, self.membership_idx] = mu_membership * 0.8
                        elif reconstruction_type == 'cb_ms0.5':  # scaling
                            z[:, self.class_idx] = mu_class
                            z[:, self.membership_idx] = mu_membership * 0.5
                        elif reconstruction_type == 'cb_ms0.25':  # scaling
                            z[:, self.class_idx] = mu_class
                            z[:, self.membership_idx] = mu_membership * 0.25
                        elif reconstruction_type == 'cb_ms0.1':  # scaling
                            z[:, self.class_idx] = mu_class
                            z[:, self.membership_idx] = mu_membership * 0.1
                        elif reconstruction_type == 'cb_mb_n1':  # + noise
                            z[:, self.class_idx] = mu_class
                            z[:, self.
                              membership_idx] = mu_membership + torch.randn_like(
                                  mu_membership).to(self.device)
                        elif reconstruction_type == 'cb_mb_n0.5':  # + noise
                            z[:, self.class_idx] = mu_class
                            z[:, self.
                              membership_idx] = mu_membership + 0.5 * torch.randn_like(
                                  mu_membership).to(self.device)
                        elif reconstruction_type == 'cb_mb_n0.1':  # + noise
                            z[:, self.class_idx] = mu_class
                            z[:, self.
                              membership_idx] = mu_membership + 0.1 * torch.randn_like(
                                  mu_membership).to(self.device)
                        elif reconstruction_type == 'cb_mr':
                            z[:, self.class_idx] = mu_class
                            z[:, self.membership_idx] = self.reparameterize(
                                mu_membership, logvar_membership)
                        elif reconstruction_type == 'cb_ms0.5_n0.5':  # scaling
                            z[:, self.class_idx] = mu_class
                            z[:, self.
                              membership_idx] = mu_membership * 0.5 + 0.5 * torch.randn_like(
                                  mu_membership).to(self.device)
                        elif reconstruction_type == 'cb_ms0.5_n0.1':  # scaling
                            z[:, self.class_idx] = mu_class
                            z[:, self.
                              membership_idx] = mu_membership * 0.5 + 0.1 * torch.randn_like(
                                  mu_membership).to(self.device)
                        elif reconstruction_type == 'cb_ms0.8_n0.2':  # scaling
                            z[:, self.class_idx] = mu_class
                            z[:, self.
                              membership_idx] = mu_membership * 0.8 + 0.2 * torch.randn_like(
                                  mu_membership).to(self.device)
                        elif reconstruction_type == 'cb_mConstant':
                            z[:, self.class_idx] = mu_class
                            for idx in range(z.shape[0]):
                                z[idx, self.membership_idx] = mu_membership[0]
                        elif reconstruction_type == 'cb_mConstant0.8':
                            z[:, self.class_idx] = mu_class
                            mu_membership_constant = 0.8 * mu_membership[0]
                            for idx in range(z.shape[0]):
                                z[idx,
                                  self.membership_idx] = mu_membership_constant
                        elif reconstruction_type == 'cb_mInter0.8':
                            z[:, self.class_idx] = mu_class
                            mu_membership_constant = 0.2 * mu_membership[0]
                            for idx in range(z.shape[0]):
                                z[idx,
                                  self.membership_idx] = 0.8 * mu_membership[
                                      idx] + mu_membership_constant

                        elif reconstruction_type == 'cb_mAvg':
                            z[:, self.class_idx] = mu_class
                            mu_membership_constant = torch.mean(mu_membership,
                                                                dim=0)
                            for idx in range(z.shape[0]):
                                z[idx,
                                  self.membership_idx] = mu_membership_constant

                        elif reconstruction_type == 'cb_mr1.2':
                            z[:, self.class_idx] = mu_class
                            std = torch.exp(0.5 * logvar_membership)
                            eps = torch.randn_like(std)
                            z[:, self.
                              membership_idx] = mu_membership + 1.2 * std * eps

                        elif reconstruction_type == 'cb_mr2.0':
                            z[:, self.class_idx] = mu_class
                            std = torch.exp(0.5 * logvar_membership)
                            eps = torch.randn_like(std)
                            z[:, self.
                              membership_idx] = mu_membership + 2. * std * eps

                            # print(mu_membership.shape)
                            # print(mu_membership[0].shape)
                            # z[:, self.membership_idx] = mu_membership[0]
                            # print(torch.repeat_interleave(mu_membership[0], mu_membership.shape[0], 1).shape)
                            # sys.exit(1)

                        # if reconstruction_type == 'cb_mb_sb':
                        #     z[:, self.class_idx] = mu_class
                        #     z[:, self.membership_idx] = mu_membership
                        #     z[:, self.style_idx] = mu_style
                        #
                        # elif reconstruction_type == 'cb_mb_sz':
                        #     z[:, self.class_idx] = mu_class
                        #     z[:, self.membership_idx] = mu_membership
                        #     z[:, self.style_idx] = torch.zeros_like(mu_style).to(self.device)
                        #
                        # elif reconstruction_type == 'cb_mz_sb':
                        #     z[:, self.class_idx] = mu_class
                        #     z[:, self.membership_idx] = torch.zeros_like(mu_membership).to(self.device)
                        #     z[:, self.style_idx] = mu_style
                        #
                        # elif reconstruction_type == 'cb_mz_sz':
                        #     z[:, self.class_idx] = mu_class
                        #     z[:, self.membership_idx] = torch.zeros_like(mu_membership).to(self.device)
                        #     z[:, self.style_idx] = torch.zeros_like(mu_style).to(self.device)
                        #
                        # elif reconstruction_type == 'cz_mb_sb':
                        #     z[:, self.class_idx] = torch.zeros_like(mu_class).to(self.device)
                        #     z[:, self.membership_idx] = mu_membership
                        #     z[:, self.style_idx] = mu_style
                        #
                        # elif reconstruction_type == 'cz_mb_sz':
                        #     z[:, self.class_idx] = torch.zeros_like(mu_class).to(self.device)
                        #     z[:, self.membership_idx] = mu_membership
                        #     z[:, self.style_idx] = torch.zeros_like(mu_style).to(self.device)

                        #
                        # elif reconstruction_type == 'cr_mb':
                        #     z[:, self.class_idx] = self.reparameterize(mu_class, logvar_class)
                        #     z[:, self.membership_idx] = mu_membership
                        #
                        # elif reconstruction_type == 'cr_mr':
                        #     z[:, self.class_idx] = self.reparameterize(mu_class, logvar_class)
                        #     z[:, self.membership_idx] = self.reparameterize(mu_membership, logvar_membership)
                        #
                        # elif reconstruction_type == 'cb_mn':
                        #     z[:, self.class_idx] = mu_class
                        #     z[:, self.membership_idx] = torch.randn_like(mu_membership).to(self.device)

                        recons_batch = self.nets['decoder'](z).cpu()
                        labels_batch = targets

                        if len(recons) == 0:
                            raws = inputs.cpu()
                            recons = recons_batch
                            labels = labels_batch

                            if dataset_type == 'train':
                                vutils.save_image(
                                    recons,
                                    os.path.join(
                                        self.reconstruction_path,
                                        '{}.png'.format(reconstruction_type)),
                                    nrow=10)
                                recon_dict[reconstruction_type] = recons

                                if recon_idx == 0:
                                    vutils.save_image(
                                        raws,
                                        os.path.join(self.reconstruction_path,
                                                     'raw.png'),
                                        nrow=10)

                        else:
                            raws = torch.cat((raws, inputs.cpu()), axis=0)
                            recons = torch.cat((recons, recons_batch), axis=0)
                            labels = torch.cat((labels, labels_batch), axis=0)

                recon_datasets_dict[dataset_type] = {
                    'recons': recons,
                    'labels': labels,
                }

                mse_list.append(F.mse_loss(recons, raws).item())

            # todo : refactor dict to CustomDataset
            torch.save(
                recon_datasets_dict,
                os.path.join(self.reconstruction_path,
                             'recon_{}.pt'.format(reconstruction_type)))

        np.save(os.path.join(self.reconstruction_path, 'mse.npy'), mse_list)

    @staticmethod
    def reparameterize(mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)

        return mu + std * eps

    def inference_z(self, z):
        """Encode ``z`` and return a latent code.

        Returns a stochastic sample (via :meth:`reparameterize`) when
        ``self.disentangle_with_reparameterization`` is set; otherwise the
        posterior mean ``mu``.
        """
        mu, logvar = self.nets['encoder'](z)
        if not self.disentangle_with_reparameterization:
            return mu
        return self.reparameterize(mu, logvar)

    def split_class_membership(self, z):
        """Return the class-dimension and membership-dimension slices of ``z``."""
        return z[:, self.class_idx], z[:, self.membership_idx]

    def get_loss_function(self):
        """Build the beta-VAE objective.

        The returned closure computes ``(total, reconstruction, KL)`` where
        total = MSE + self.beta * KLD. ``self.beta`` is read at call time, so
        later changes to the attribute are honored.
        """

        def loss_function(recon_x, x, mu, logvar):
            # Summed squared-error reconstruction term.
            mse = F.mse_loss(recon_x, x, reduction='sum')
            # Analytic KL divergence from N(mu, sigma^2) to N(0, I).
            kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
            return mse + self.beta * kld, mse, kld

        return loss_function
def main():
    """Train a CNN on MNIST under DistributedDataParallel (DDP).

    Devices and the random seed are configured by the DDP helpers
    (``move_model_to_device`` / ``DDP_prepare``); this function only wires up
    the data, model, optimizer, scheduler and the epoch loop, then reports
    total wall-clock training time via ``master_print``.
    """
    # Training settings (fixed here; this DDP variant does no CLI parsing).
    batch_size = 64
    epochs = 14
    lr = 1.0
    gamma = 0.7
    log_interval = 10
    save_model = False
    # Number of processes for dataloader (work in CPU)
    num_workers = 1

    # Standard MNIST normalization constants (mean, std).
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])
    dataset1 = datasets.MNIST('./data',
                              train=True,
                              download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('./data', train=False, transform=transform)

    # DDP Step 2: Move model to devices.
    model = Net()
    move_model_to_device(model)

    # DDP Step 3: Use DDP_prepare to prepare datasets and loaders.
    model, train_loader, test_loader, train_data_sampler, test_data_sampler = DDP_prepare(
        train_dataset=dataset1,
        test_dataset=dataset2,
        num_data_processes=num_workers,
        global_batch_size=batch_size,
        # In case you have sophisticated data processing function, pass it to collate_fn (i.e., collate_fn of the DataLoader)
        collate_fn=None,
        model=model)

    optimizer = optim.Adadelta(model.parameters(), lr=lr)

    # Multiply the learning rate by `gamma` after every epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    start = time.perf_counter()
    for epoch in range(1, epochs + 1):
        train(log_interval, model, train_loader, train_data_sampler, optimizer,
              epoch)
        test(model, test_loader)
        scheduler.step()

    end = time.perf_counter()
    master_print("Total Training Time %.2f seconds" % (end - start))

    if save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
# 示例#4 (example #4 — scraper artifact, commented out: was a live NameError)
# 0
def main():
    """MNIST training/evaluation entry point.

    With ``--evaluate``, loads the model from ``--load-model`` and reports
    accuracy on the official test set and the train set. Otherwise trains
    ``Net`` on growing fractions (1/16 .. 1/2) of a stratified 85/15
    train/validation split, saving loss curves and (optionally) one model
    per fraction.
    """
    # Training settings
    # Use the command line to modify the default settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=14,
                        metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr',
                        type=float,
                        default=1.0,
                        metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument(
        '--step',
        type=int,
        default=1,
        metavar='N',
        help='number of epochs between learning rate reductions (default: 1)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.7,
                        metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')

    parser.add_argument('--evaluate',
                        action='store_true',
                        default=False,
                        help='evaluate your model on the official test set')
    parser.add_argument('--load-model', type=str, help='model file path')

    parser.add_argument('--save-model',
                        action='store_true',
                        default=True,
                        help='For Saving the current Model')

    parser.add_argument('--split-dataset',
                        action='store_true',
                        default=False,
                        help='For Creating Validation Split')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # Evaluate on the official test set
    if args.evaluate:
        # BUG FIX: was `assert os.path.exists(args.load_model)` — `assert` is
        # stripped under `python -O`, and os.path.exists(None) raises a
        # confusing TypeError when --load-model is omitted. Fail fast with an
        # explicit, descriptive error instead.
        if not args.load_model or not os.path.exists(args.load_model):
            raise FileNotFoundError(
                'model file not found: {!r} (pass --load-model)'.format(
                    args.load_model))

        # Set the test model
        model = Net().to(device)
        model.load_state_dict(torch.load(args.load_model))

        test_dataset = datasets.MNIST('../data',
                                      train=False,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307, ),
                                                               (0.3081, ))
                                      ]))

        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)

        test(model, device, test_loader)

        train_dataset = datasets.MNIST(
            '../data',
            train=True,
            download=True,
            transform=transforms.Compose([  # Data preprocessing
                transforms.ToTensor(),  # Add data augmentation here
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]))
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)
        test(model, device, train_loader)

        return

    # Pytorch has default MNIST dataloader which loads data at each iteration
    train_dataset = datasets.MNIST(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([  # Data preprocessing
            #transforms.RandomAffine(10, translate = (0.1, 0.1)), # Random translation
            transforms.ToTensor(),  # Add data augmentation here
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]))

    # Build (or reload) a stratified, disjoint 85/15 train/validation split:
    # indices are grouped per class and each class is split independently.
    if args.split_dataset:
        print("Creating New Split!")
        subset_indices_train = []
        subset_indices_valid = []

        # Get list of indices based on class
        classes = {}
        for i in range(len(train_dataset)):
            y = train_dataset[i][1]
            if y not in classes:
                classes[y] = [i]
            else:
                classes[y].append(i)
        # Randomly sample 85% of each set of classes for training
        for key in classes:
            idx = classes[key]
            idx = np.random.permutation(idx)
            subset_indices_train.extend(idx[:round(0.85 * len(idx))])
            subset_indices_valid.extend(idx[round(0.85 * len(idx)):])
        np.save("train_idx.npy", subset_indices_train)
        np.save("valid_idx.npy", subset_indices_valid)
    else:
        print("Loading previous split!")
        subset_indices_train = np.load("train_idx.npy")
        subset_indices_valid = np.load("valid_idx.npy")

    # Learning-curve experiment: train on growing fractions of the train set.
    for portion in [1 / 16, 1 / 8, 1 / 4, 1 / 2]:
        curr_train_indices = np.random.choice(subset_indices_train,
                                              round(portion *
                                                    len(subset_indices_train)),
                                              replace=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            sampler=SubsetRandomSampler(curr_train_indices))
        val_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.test_batch_size,
            sampler=SubsetRandomSampler(subset_indices_valid))

        # Load your model [fcNet, ConvNet, Net]
        model = Net().to(device)

        # Try different optimzers here [Adam, SGD, RMSprop]
        optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

        # Set your learning rate scheduler
        scheduler = StepLR(optimizer, step_size=args.step, gamma=args.gamma)

        # Training loop
        train_losses = []
        test_losses = []
        for epoch in range(1, args.epochs + 1):
            train_loss = train(args, model, device, train_loader, optimizer,
                               epoch)
            test_loss = test(model, device, val_loader)
            train_losses.append(train_loss)
            test_losses.append(test_loss)
            scheduler.step()  # learning rate scheduler

        # You may optionally save your model at each epoch here
        np.save("train_loss.npy", np.array(train_losses))
        np.save("test_loss.npy", np.array(test_losses))

        print("Final Performance!")
        print("Validation Set:")
        test(model, device, val_loader)
        print("Training Set:")
        test(model, device, train_loader)
        if args.save_model:
            # BUG FIX: torch.save fails if the target directory is missing.
            os.makedirs("models", exist_ok=True)
            torch.save(model.state_dict(),
                       "models/mnist_model_{}.pt".format(portion))
# 示例#5 (example #5 — scraper artifact, commented out: was a live NameError)
# 0
# Build the ImageNet datasets and the selected backbone, then hand everything
# to the Bone training harness.
datasets = imagenet.get_datasets(data_dir)

if args.model_name == 'resnet':
    model = resnet50(num_classes=num_classes)
elif args.model_name == 'senet':
    model = se_resnet50(num_classes=num_classes)
elif args.model_name == 'srmnet':
    model = srm_resnet50(num_classes=num_classes)
else:
    # BUG FIX: an unrecognized name previously fell through and caused a
    # confusing NameError on `model` below; fail fast with a clear message.
    raise ValueError('unknown model_name: {!r} '
                     '(expected resnet, senet or srmnet)'.format(
                         args.model_name))

# SGD with the standard ImageNet recipe; LR is multiplied by 0.1 every
# 30 epochs by the StepLR scheduler below.
optimizer = optim.SGD(model.parameters(),
                      lr=0.1,
                      momentum=0.9,
                      weight_decay=1e-4)

scheduler = StepLR(optimizer, 30, 0.1)
criterion = nn.CrossEntropyLoss()

backbone = Bone(model,
                datasets,
                criterion,
                optimizer,
                scheduler=scheduler,
                scheduler_after_ep=False,
                metric_fn=utils.accuracy_metric,
                metric_increase=True,
                batch_size=batch_size,
                num_workers=num_workers,
                weights_path=f'weights/imagenet_best_{args.model_name}.pth',
                log_dir=f'logs/imagenet/{args.model_name}')
def main():
    """Train SeqModel on Spotify sequential-skip data with periodic validation.

    Reads module-level configuration (``args``, ``TR_BATCH_SZ``,
    ``TS_BATCH_SZ``, ``GPU``, ``LEARNING_RATE``, ``EPOCHS``,
    ``MODEL_SAVE_PATH``) and appends to the module-level history lists
    (``hist_trloss``, ``hist_tracc``, ``hist_vacc``, ``hist_vloss``).
    Checkpoints (model, optimizer, scheduler, histories) are written every
    25000 sessions and at the end of each epoch.
    """
    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,
        data_sel=(0, 99965071),  # first 80% of the trainset
        batch_size=TR_BATCH_SZ,
        shuffle=True,
        seq_mode=True)  # seq_mode implemented

    mval_loader = SpotifyDataloader(
        config_fpath=args.config,
        mtrain_mode=True,  # True, because we use part of trainset as testset
        data_sel=(99965071, 104965071),  #(99965071, 124950714), # last 20% held out as test
        batch_size=TS_BATCH_SZ,
        shuffle=False,
        seq_mode=True)

    # Init neural net
    SM = SeqModel().cuda(GPU)
    SM_optim = torch.optim.Adam(SM.parameters(), lr=LEARNING_RATE)
    SM_scheduler = StepLR(SM_optim, step_size=1, gamma=0.8)

    # Load checkpoint: resume from the most recently written file, if asked.
    if args.load_continue_latest is None:
        START_EPOCH = 0
    else:
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"),
                           key=os.path.getctime)
        checkpoint = torch.load(latest_fpath,
                                map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(
            latest_fpath, checkpoint['loss']))
        SM.load_state_dict(checkpoint['SM_state'])
        SM_optim.load_state_dict(checkpoint['SM_opt_state'])
        SM_scheduler.load_state_dict(checkpoint['SM_sch_state'])
        START_EPOCH = checkpoint['ep']

    # Train
    for epoch in trange(START_EPOCH,
                        EPOCHS,
                        desc='epochs',
                        position=0,
                        ascii=True):
        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query = 0
        total_trloss = 0
        for session in trange(len(tr_sessions_iter),
                              desc='sessions',
                              position=1,
                              ascii=True):
            SM.train()
            # BUG FIX: `iterator.next()` is Python 2 only; Python 3 iterators
            # expose __next__, called via the next() builtin.
            x, labels, y_mask, num_items, index = next(
                tr_sessions_iter)  # FIXED 13.Dec. SEPARATE LOGS. QUERY SHOULT NOT INCLUDE LOGS

            # Sample data for 'support' and 'query': ex) 15 items = 7 sup, 8 queries...
            num_support = num_items[:, 0].detach().numpy().flatten(
            )  # If num_items was odd number, query has one more item.
            num_query = num_items[:, 1].detach().numpy().flatten()
            batch_sz = num_items.shape[0]

            # x: the first 10 items out of 20 are support items left-padded with zeros. The last 10 are queries right-padded.
            x = x.permute(0, 2, 1)  # bx70*20
            x_sup = Variable(
                torch.cat((x[:, :, :10], labels[:, :10].unsqueeze(1)),
                          1)).cuda(GPU)  # bx71(41+29+1)*10
            x_que = torch.zeros(batch_sz, 72, 20)
            x_que[:, :41, :10] = x[:, :41, :10].clone()  # fill with x_sup_log
            x_que[:, 41:70, :] = x[:, 41:, :].clone(
            )  # fill with x_sup_feat and x_que_feat
            x_que[:, 70, :10] = 1  # support marking
            x_que[:, 71, :10] = labels[:, :10]  # labels marking
            x_que = Variable(x_que).cuda(GPU)  # bx29*10

            # y
            y = labels.clone()  # bx20

            # y_mask: zero out the support positions so only queries are scored
            y_mask_que = y_mask.clone()
            y_mask_que[:, :10] = 0

            # Forward & update
            y_hat, att = SM(x_sup, x_que)  # y_hat: b*20, att: bx10*20

            # Calcultate BCE loss
            loss = F.binary_cross_entropy_with_logits(
                input=y_hat * y_mask_que.cuda(GPU),
                target=y.cuda(GPU) * y_mask_que.cuda(GPU))
            total_trloss += loss.item()
            SM.zero_grad()
            loss.backward()
            # Gradient Clipping
            #torch.nn.utils.clip_grad_norm_(SM.parameters(), 0.5)
            SM_optim.step()

            # Decision
            y_prob = torch.sigmoid(
                y_hat * y_mask_que.cuda(GPU)).detach().cpu().numpy()  # bx20
            # BUG FIX: np.int was deprecated in NumPy 1.20 and removed in
            # 1.24; the builtin int is the documented replacement.
            y_pred = (y_prob[:, 10:] > 0.5).astype(int)  # bx10
            y_numpy = labels[:, 10:].numpy()  # bx10
            # Acc
            total_corrects += np.sum(
                (y_pred == y_numpy) * y_mask_que[:, 10:].numpy())
            total_query += np.sum(num_query)

            # Restore GPU memory
            del loss, y_hat

            # Periodic console report with a sample attention map.
            if (session + 1) % 500 == 0:
                hist_trloss.append(total_trloss / 900)
                hist_tracc.append(total_corrects / total_query)
                # Prepare display
                sample_att = att[0, (10 - num_support[0]):10,
                                 (10 - num_support[0]):(
                                     10 +
                                     num_query[0])].detach().cpu().numpy()

                sample_sup = labels[0, (
                    10 - num_support[0]):10].long().numpy().flatten()
                sample_que = y_numpy[0, :num_query[0]].astype(int)
                sample_pred = y_pred[0, :num_query[0]]
                sample_prob = y_prob[0, 10:10 + num_query[0]]

                tqdm.write(
                    np.array2string(sample_att,
                                    formatter={
                                        'float_kind':
                                        lambda sample_att: "%.2f" % sample_att
                                    }).replace('\n ', '').replace(
                                        '][', ']\n[').replace('[[', '['))
                tqdm.write("S:" + np.array2string(sample_sup) + '\n' + "Q:" +
                           np.array2string(sample_que) + '\n' + "P:" +
                           np.array2string(sample_pred) + '\n' + "prob:" +
                           np.array2string(sample_prob))
                tqdm.write(
                    "tr_session:{0:}  tr_loss:{1:.6f}  tr_acc:{2:.4f}".format(
                        session, hist_trloss[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query = 0
                total_trloss = 0

            if (session + 1) % 25000 == 0:
                # Validation
                validate(mval_loader, SM, eval_mode=True, GPU=GPU)
                # Save
                torch.save(
                    {
                        'ep': epoch,
                        'sess': session,
                        'SM_state': SM.state_dict(),
                        'loss': hist_trloss[-1],
                        'hist_vacc': hist_vacc,
                        'hist_vloss': hist_vloss,
                        'hist_trloss': hist_trloss,
                        'SM_opt_state': SM_optim.state_dict(),
                        'SM_sch_state': SM_scheduler.state_dict()
                    }, MODEL_SAVE_PATH +
                    "check_{0:}_{1:}.pth".format(epoch, session))
        # Validation
        validate(mval_loader, SM, eval_mode=True, GPU=GPU)
        # Save
        torch.save(
            {
                'ep': epoch,
                'sess': session,
                'SM_state': SM.state_dict(),
                'loss': hist_trloss[-1],
                'hist_vacc': hist_vacc,
                'hist_vloss': hist_vloss,
                'hist_trloss': hist_trloss,
                'SM_opt_state': SM_optim.state_dict(),
                'SM_sch_state': SM_scheduler.state_dict()
            }, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
        SM_scheduler.step()
# 示例#7 (example #7 — scraper artifact, commented out: was a live NameError)
# 0
def train(params):
    """Train with 5-fold StratifiedKFold and build test-set predictions.

    Args:
        params: dict of hyperparameters and file paths; keys read here
            include 'task', 'seed', 'cols', 'embed_dir', 'seqs_file',
            'feas_file', 'batch_size', 'max_len', 'hidden_size',
            'num_layers', 'drop_out', 'lr', 'gamma', 'label_smooth',
            'early_stop_round', 'num_epochs'.

    Returns:
        `sub`: DataFrame indexed by user id with 20 probability columns.

    NOTE(review): the unconditional `break` at the bottom of the fold loop
    stops after the first fold, so the checkpoint-reload / sub-prediction
    code after it never runs — presumably a debugging leftover; confirm
    before relying on the returned `sub`.
    """
    logger = get_logger('{}.log'.format(params['task']), '{}_logger'.format(params['task']))
    logger.info('start {}'.format(params['task']))

    set_all_seed(params['seed'])

    # Log the full configuration for reproducibility.
    for key, value in params.items():
        logger.info('{} : {}'.format(key, value))

    logger.info('loading seqs, feas and w2v embeddings ...')
    train_val_data, sub_data, embeddings, embed_size, fea_size = load_data(params['cols'], params['embed_dir'], params['seqs_file'], params['feas_file'])

    logger.info('embed_size : {} | fea_size : {}'.format(embed_size, fea_size))
    batch_size = params['batch_size']
    sub_dataset = SeqDataSet(sub_data['seqs'],
                             sub_data['feas'],
                             sub_data['users'], len(params['cols']), params['max_len'], 'sub')
    sub_loader = data.DataLoader(sub_dataset, batch_size * 10, shuffle=False, collate_fn=sub_dataset.collate_fn, pin_memory=True)

    # Accumulator for per-user class probabilities (20 classes).
    sub = np.zeros(shape=(sub_data['num'], 20))
    sub = pd.DataFrame(sub, index=sub_data['users'])

    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=params['seed'])

    for i, (train_idx, val_idx) in enumerate(skf.split(train_val_data['feas'], train_val_data['labels'])):
        logger.info('------------------------------------------{} fold------------------------------------------'.format(i))
        train_dataset = SeqDataSet(train_val_data['seqs'][train_idx],
                                   train_val_data['feas'][train_idx],
                                   train_val_data['labels'][train_idx], len(params['cols']), params['max_len'], 'train')
        train_loader = data.DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=train_dataset.collate_fn, pin_memory=True)

        val_dataset = SeqDataSet(train_val_data['seqs'][val_idx],
                                 train_val_data['feas'][val_idx],
                                 train_val_data['labels'][val_idx], len(params['cols']), params['max_len'], 'val')
        val_loader = data.DataLoader(val_dataset, batch_size * 10, shuffle=False, collate_fn=val_dataset.collate_fn, pin_memory=True)

        logger.info('train samples : {} | val samples : {} | sub samples : {}'.format(len(train_idx), len(val_idx), sub_data['num']))
        logger.info('loading net ...')

        embed_net = embedNet(embeddings).cuda()
        net = Net(embed_size, fea_size, params['hidden_size'], params['num_layers'], params['drop_out']).cuda()

        #optimizer = Ranger(params=net.parameters(), lr=params['lr'])
        optimizer = optim.AdamW(params=net.parameters(), lr=params['lr'])
        scheduler = StepLR(optimizer, step_size=2, gamma=params['gamma'])
        #scheduler = CosineAnnealingLR(optimizer, T_max=params['num_epochs'])
        loss_func = CrossEntropyLabelSmooth(20, params['label_smooth'])
        #loss_func = nn.CrossEntropyLoss()

        earlystop = EarlyStopping(params['early_stop_round'], logger, params['task'] + str(i))

        for epoch in range(params['num_epochs']):
            train_loss, val_loss = 0.0, 0.0
            train_age_acc, val_age_acc = 0.0, 0.0
            train_gender_acc, val_gender_acc = 0.0, 0.0
            train_acc, val_acc = 0.0, 0.0

            n, m = 0, 0
            lr_now = scheduler.get_last_lr()[0]
            logger.info('--> [Epoch {:02d}/{:02d}] lr = {:.7f}'.format(epoch, params['num_epochs'], lr_now))

            # Train the model for one epoch
            net.train()
            for seqs, feas, lens, labels in tqdm(train_loader, desc='[Epoch {:02d}/{:02d}] Train'.format(epoch, params['num_epochs'])):
                seqs = seqs.cuda()
                feas = feas.cuda()
                lens = lens.cuda()
                labels = labels.cuda()

                logits = net(embed_net(seqs), feas, lens)
                loss = loss_func(logits, labels)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                # NOTE(review): label appears to encode gender*10 + age bucket
                # (acc uses % 10 for age and // 10 for gender) — verify
                # against the label-construction code.
                train_loss += loss.detach() * labels.shape[0]
                train_age_acc += (logits.argmax(dim=1).detach() % 10 == labels % 10).sum()
                train_gender_acc += (logits.argmax(dim=1).detach() // 10 == labels // 10).sum()

                n += lens.shape[0]

            scheduler.step()

            train_loss = (train_loss / n).item()
            train_age_acc = (train_age_acc / n).item()
            train_gender_acc = (train_gender_acc / n).item()
            train_acc = train_age_acc + train_gender_acc

            # Evaluate on the validation fold
            net.eval()
            with torch.no_grad():
                for seqs, feas, lens, labels in tqdm(val_loader, desc='[Epoch {:02d}/{:02d}]  Val '.format(epoch, params['num_epochs'])):
                    seqs = seqs.cuda()
                    feas = feas.cuda()
                    lens = lens.cuda()
                    labels = labels.cuda()

                    logits = net(embed_net(seqs), feas, lens)
                    loss = loss_func(logits, labels)

                    val_loss += loss.detach() * labels.shape[0]
                    val_age_acc += (logits.argmax(dim=1) % 10 == labels % 10).sum()
                    val_gender_acc += (logits.argmax(dim=1).detach() // 10 == labels // 10).sum()

                    m += lens.shape[0]

                val_loss = (val_loss / m).item()
                val_age_acc = (val_age_acc / m).item()
                val_gender_acc = (val_gender_acc / m).item()
                val_acc = val_age_acc + val_gender_acc

            logger.info('train_loss {:.5f} | train_gender_acc {:.5f} | train_age_acc {:.5f} | train_acc {:.5f} | val_loss {:.5f} | val_gender_acc {:.5f} | val_age_acc {:.5f} | val_acc {:.5f}'
                        .format(train_loss, train_gender_acc, train_age_acc, train_acc, val_loss, val_gender_acc, val_age_acc, val_acc))

            # Early stopping on validation loss/accuracy
            earlystop(val_loss, val_acc, net)
            if earlystop.early_stop:
                break

        # NOTE(review): unconditional break — only fold 0 runs and everything
        # below in this loop (checkpoint reload + sub prediction) is dead code.
        break 
        net.load_state_dict(torch.load('{}_checkpoint.pt'.format(params['task']+str(i))))
        logger.info('predicting sub ...')
        net.eval()
        with torch.no_grad():
            # Average 10 stochastic prediction passes over the sub set.
            for it in range(10):
                probs = []
                users = []
                for seqs, feas, lens, ids in tqdm(sub_loader, desc='predict_{}'.format(it)):
                    seqs = seqs.cuda()
                    feas = feas.cuda()
                    lens = lens.cuda()

                    logits = net(embed_net(seqs), feas, lens)
                    logits = F.softmax(logits, dim=1)

                    probs.append(logits)
                    users.append(ids)

                probs = torch.cat(probs).cpu().numpy()
                users = torch.cat(users).numpy()
                sub += pd.DataFrame(probs, users)
            sub = sub / 10

    return sub
# 示例#8 (example #8 — scraper artifact, commented out: was a live NameError)
# 0
def train_network(n_epochs, learning_rate, patience, folder_path, use_cuda, batch_size):
    """Train a PointPillars network with early stopping and return loss history.

    Args:
        n_epochs: (int) maximum number of epochs to train
        learning_rate: (float) initial learning rate for the optimizer
        patience: (int) early-stopping patience, in epochs
        folder_path: (string) directory where EarlyStopping saves checkpoints
        use_cuda: (bool) move the model and every batch to the GPU when True
        batch_size: (int) mini-batch size for the train and validation loaders

    Returns:
        (train_loss, val_loss): lists of the mean loss per epoch.
    """
    # Dataset locations are hard coded for this machine.
    data_set_path_train = '/home/annika_lundqvist144/ply_files/_out_Town01_190402_1/pc'
    csv_path_train = '/home/annika_lundqvist144/ply_files/_out_Town01_190402_1/Town01_190402_1.csv'
    translation, rotation = 1, 1
    data_set_path_val = '/home/annika_lundqvist144/ply_files/validation_set/pc'
    csv_path_val = '/home/annika_lundqvist144/ply_files/validation_set/validation_set.csv'

    net = PointPillars(batch_size, use_cuda)
    print('=======> NETWORK NAME: =======> ', net.name())
    if use_cuda:
        net.cuda()
    print(' ')

    train_loader, val_loader = get_train_loader_pointpillars(batch_size, data_set_path_train, csv_path_train,
                                                             data_set_path_val, csv_path_val, rotation, translation,
                                                             {'num_workers': 8})

    # Print all of the hyperparameters of the training iteration:
    print("===== HYPERPARAMETERS =====")
    print("batch_size =", batch_size)
    print("epochs =", n_epochs)
    print("initial learning_rate =", learning_rate)
    print('patience:', patience)
    print("=" * 27)

    # per-epoch loss history, returned to the caller
    val_loss = []
    train_loss = []

    # initialize the early_stopping object
    early_stopping = EarlyStopping(folder_path, patience, verbose=True)

    n_batches = len(train_loader)
    print('Number of batches: ', n_batches)

    # Create our loss and optimizer functions
    loss, optimizer = create_loss_and_optimizer(net, learning_rate)
    scheduler = StepLR(optimizer, step_size=100, gamma=0.5)

    # Time for printing
    training_start_time = time.time()

    for epoch in range(n_epochs):
        params = optimizer.state_dict()['param_groups']
        print(' ')
        print('learning rate: ', params[0]['lr'])

        running_loss = 0.0
        print_every = 13  # how many mini-batches between progress prints
        start_time = time.time()
        total_train_loss = 0

        net = net.train()
        time_epoch = time.time()
        for i, data in enumerate(train_loader, 1):
            # Each training sample contains: sweep features (xp,yp,z),
            # sweep coordinates (x,y,z), map (cutout) features, map
            # coordinates, and labels.
            sweep = data['sweep']
            sweep_coordinates = data['sweep_coordinates']
            cutout = data['cutout']
            cutout_coordinates = data['cutout_coordinates']
            labels = data['labels']

            if use_cuda:
                # BUGFIX: `.cuda(async=True)` is a SyntaxError on Python 3.7+
                # (`async` became a keyword); PyTorch renamed the argument to
                # `non_blocking` in 0.4.
                sweep = sweep.cuda(non_blocking=True)
                sweep_coordinates = sweep_coordinates.cuda(non_blocking=True)
                cutout = cutout.cuda(non_blocking=True)
                cutout_coordinates = cutout_coordinates.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)

            # Set the parameter gradients to zero, then forward/backward/update.
            optimizer.zero_grad()
            outputs = net.forward(sweep.float(), cutout.float(), sweep_coordinates.float(),
                                  cutout_coordinates.float())
            loss_size = loss(outputs, labels.float())
            loss_size.backward()
            optimizer.step()

            # Accumulate statistics.
            running_loss += loss_size.item()
            total_train_loss += loss_size.item()

            # BUGFIX: an `if True:` debug leftover printed every batch while
            # still dividing a single batch's loss by `print_every`; restored
            # to the intended interval averaging.
            if i % print_every == 0:
                print('Epoch [{}/{}], Batch [{}/{}], Loss: {:.4f}, Time: '
                      .format(epoch + 1, n_epochs, i, n_batches, running_loss / print_every),
                      time.time() - time_epoch)
                running_loss = 0.0
                time_epoch = time.time()

            # Free batch tensors early to keep peak GPU memory down.
            del data, sweep, cutout, labels, outputs, loss_size

        # BUGFIX: since PyTorch 1.1 the LR scheduler must be stepped *after*
        # the epoch's optimizer updates, not before the first one.
        scheduler.step()

        # At the end of the epoch, do a pass on the validation set.
        total_val_loss = 0
        net = net.eval()
        with torch.no_grad():
            for data in val_loader:
                sample = data['sample']
                labels = data['labels']

                if use_cuda:
                    sample, labels = sample.cuda(), labels.cuda()

                # Forward pass only; note the validation loader yields a single
                # 'sample' tensor rather than the four training tensors.
                val_outputs = net.forward(sample)
                val_loss_size = loss(val_outputs, labels.float())
                total_val_loss += val_loss_size.item()

                del data, sample, labels, val_outputs, val_loss_size

        print("Training loss: {:.4f}".format(total_train_loss / len(train_loader)),
              ", Validation loss: {:.4f}".format(total_val_loss / len(val_loader)),
              ", Time: {:.2f}s".format(time.time() - start_time))
        print(' ')
        # save the loss for each epoch
        train_loss.append(total_train_loss / len(train_loader))
        val_loss.append(total_val_loss / len(val_loader))

        # Checkpoint if validation loss improved; stop after `patience`
        # epochs without improvement.
        early_stopping(epoch, total_train_loss, total_val_loss, net, optimizer)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    print("Training finished, took {:.2f}s".format(time.time() - training_start_time))
    return train_loss, val_loss
示例#9
0
def main():
    """Parse CLI options, train a CNN on QMNIST each epoch, and optionally save it."""
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch QMNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    cuda_available = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if cuda_available else "cpu")

    # Pin memory and use a worker only when batches are headed for the GPU.
    loader_kwargs = {'num_workers': 1, 'pin_memory': True} if cuda_available else {}
    train_set = datasets.QMNIST('data', train=True, download=True,
                                transform=transforms.ToTensor())
    test_set = datasets.QMNIST('data', train=False, download=True,
                               transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size,
                                               shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.test_batch_size,
                                              shuffle=False, **loader_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    # Multiply the LR by `gamma` after every epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "qmnist_cnn.pt")
示例#10
0
def main():
    """Jointly train attribute + relation networks for zero-shot learning on CUB.

    Parses seed / batch / epoch options from the command line, loads res101
    image features and attribute splits from .mat files, builds train /
    seen-test / unseen-test / combined-test loaders, and trains
    AttributeNetwork_End_to_End with RelationNetwork using cross-entropy over
    relation scores. After each epoch prints ZSL and generalized-ZSL
    (unseen / seen / harmonic-mean) accuracies. Requires a CUDA device.
    """
    parser = argparse.ArgumentParser(description="Zero Shot Learning")
    parser.add_argument("-s", "--seed", type=int, default=1234)
    parser.add_argument("-b", "--batch_size", type=int, default=50)
    parser.add_argument("-e", "--epochs", type=int, default=1000)
    parser.add_argument("-t", "--test_episode", type=int, default=1000)
    parser.add_argument("-l", "--learning_rate", type=float, default=1e-4)
    parser.add_argument("-g", "--gpu", type=int, default=0)
    args = parser.parse_args()

    BATCH_SIZE = args.batch_size
    EPOCHS = args.epochs
    # NOTE(review): TEST_EPISODE and GPU are bound here but not referenced
    # later in this function — presumably used by helpers or kept for parity
    # with related scripts; verify before removing.
    TEST_EPISODE = args.test_episode
    LEARNING_RATE = args.learning_rate
    GPU = args.gpu

    np.set_printoptions(threshold=np.inf)

    # Seed every RNG (numpy / python / torch / CUDA) and pin cuDNN for
    # reproducible runs.
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # step 1: init dataset
    print("init dataset")

    dataroot = './data'
    dataset = 'CUB1_data'
    image_embedding = 'res101'
    class_embedding = 'original_att_splits'

    attribute_values = np.load("CUB_200_2011/attribute_values.npy")

    # res101 features: stored transposed (features x samples) in the .mat file.
    matcontent = sio.loadmat(dataroot + "/" + dataset + "/" + image_embedding +
                             ".mat")
    feature = matcontent['features'].T

    label = matcontent['labels'].astype(int).squeeze() - 1

    matcontent = sio.loadmat(dataroot + "/" + dataset + "/" + class_embedding +
                             ".mat")
    # numpy array index starts from 0, matlab starts from 1
    trainval_loc = matcontent['trainval_loc'].squeeze() - 1
    test_seen_loc = matcontent['test_seen_loc'].squeeze() - 1
    test_unseen_loc = matcontent['test_unseen_loc'].squeeze() - 1

    embedding = matcontent['att'].T
    all_embeddings = np.array(embedding)  # (200, 312)

    train_features = feature[trainval_loc]  # train_features
    train_attributes = attribute_values[trainval_loc]  # train_attribute_values
    train_label = label[trainval_loc].astype(int)  # train_label

    test_unseen_features = feature[test_unseen_loc]  # test_unseen_features
    test_unseen_attributes = attribute_values[
        test_unseen_loc]  # test_unseen_attributes_values
    test_unseen_label = label[test_unseen_loc].astype(int)  # test_unseen_label

    test_seen_features = feature[test_seen_loc]  #test_seen_features
    test_seen_attributes = attribute_values[
        test_seen_loc]  # test_seen_attributes_values
    test_seen_label = label[test_seen_loc].astype(int)  # test_seen_label

    # Combined (generalized-ZSL) test split: unseen classes first, then seen.
    test_features = np.concatenate((test_unseen_features, test_seen_features),
                                   0)
    test_attributes = np.concatenate(
        (test_unseen_attributes, test_seen_attributes), 0)
    test_label = np.concatenate((test_unseen_label, test_seen_label), 0)

    train_label_set = np.unique(train_label)
    test_unseen_label_set = np.unique(test_unseen_label)
    test_seen_label_set = np.unique(test_seen_label)
    test_label_set = np.unique(test_label)

    # Convert every split to torch tensors; labels get a trailing dim of 1.
    train_features = torch.from_numpy(train_features)  # [5646, 2048]
    train_attributes = torch.from_numpy(train_attributes)  # [5646, 312]
    train_label = torch.from_numpy(train_label).unsqueeze(1)  # [5646, 1]

    test_unseen_features = torch.from_numpy(
        test_unseen_features)  # [2967, 2048]
    test_unseen_attributes = torch.from_numpy(
        test_unseen_attributes)  # [2967, 312]
    test_unseen_label = torch.from_numpy(test_unseen_label).unsqueeze(
        1)  # [2967, 1]

    test_seen_features = torch.from_numpy(test_seen_features)  # [1764, 2048]
    test_seen_attributes = torch.from_numpy(
        test_seen_attributes)  # [1764, 312]
    test_seen_label = torch.from_numpy(test_seen_label).unsqueeze(
        1)  # [1764, 1]

    test_features = torch.from_numpy(test_features)  # [4731, 2048]
    test_attributes = torch.from_numpy(test_attributes)  # [4731, 312]
    test_label = torch.from_numpy(test_label).unsqueeze(1)  # [4731, 1]

    # init network
    print("init networks")
    attribute_network = AttributeNetwork_End_to_End(2048, 1200, 312).cuda()
    relation_network = RelationNetwork(624, 300, 100).cuda()

    train_data = IntegratedDataset(train_features, train_label,
                                   train_attributes)
    test_data = IntegratedDataset(test_features, test_label, test_attributes)
    test_unseen_data = IntegratedDataset(test_unseen_features,
                                         test_unseen_label,
                                         test_unseen_attributes)
    test_seen_data = IntegratedDataset(test_seen_features, test_seen_label,
                                       test_seen_attributes)

    # NOTE(review): only `ce` is used in the training loop below; `mse` and
    # `nll` appear unused in this function — confirm before deleting.
    mse = nn.MSELoss().cuda()
    ce = nn.CrossEntropyLoss().cuda()
    nll = torch.nn.NLLLoss(weight=torch.FloatTensor([0.1, 1.])).cuda()
    #mse = nn.BCELoss().cuda()

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True)
    test_unseen_loader = DataLoader(test_unseen_data,
                                    batch_size=BATCH_SIZE,
                                    shuffle=True)
    test_seen_loader = DataLoader(test_seen_data,
                                  batch_size=BATCH_SIZE,
                                  shuffle=True)

    # One Adam optimizer + step decay schedule per sub-network.
    attribute_network_optim = torch.optim.Adam(attribute_network.parameters(),
                                               lr=LEARNING_RATE)
    attribute_network_scheduler = StepLR(attribute_network_optim,
                                         step_size=30000,
                                         gamma=0.5)

    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=30000,
                                        gamma=0.5)
    print("training networks...")

    total_steps = 0
    # NOTE(review): best_loss_att is initialized but never updated or read
    # in this function.
    best_loss_att = 10000

    for epoch in range(EPOCHS):
        attribute_network.train()
        relation_network.train()

        for i, (batch_features, batch_labels,
                batch_att) in enumerate(train_loader):
            # NOTE(review): stepping the schedulers before the optimizers and
            # passing an epoch index to step() are both deprecated patterns in
            # PyTorch >= 1.1 / 1.4 — kept as-is to preserve behavior.
            attribute_network_scheduler.step(total_steps)
            relation_network_scheduler.step(total_steps)

            batch_features = batch_features.float().cuda()
            pred_embeddings = attribute_network(batch_features)

            # Class prototypes present in this batch only.
            sample_labels = np.unique(batch_labels.squeeze().numpy())
            sample_embeddings = all_embeddings[sample_labels]
            sample_embeddings = torch.from_numpy(
                sample_embeddings).float().cuda()
            class_num = sample_embeddings.shape[0]

            # Pair every predicted embedding with every class embedding:
            # concatenation along dim 2 yields rows of size 312 + 312 = 624.
            embeddings_bunch = sample_embeddings.unsqueeze(0).repeat(
                len(batch_features), 1, 1)
            attributes_bunch = pred_embeddings.unsqueeze(0).repeat(
                class_num, 1, 1)
            attributes_bunch = torch.transpose(attributes_bunch, 0, 1)
            relation_pairs = torch.cat((embeddings_bunch, attributes_bunch),
                                       2).view(-1, 624)
            relations = relation_network(relation_pairs).view(-1, class_num)

            # Re-index labels to positions within this batch's class subset.
            re_batch_labels = []
            for label in batch_labels.numpy():
                index = np.argwhere(sample_labels == label)
                re_batch_labels.append(index[0][0])
            re_batch_labels = torch.LongTensor(re_batch_labels).cuda()

            loss_rel = ce(relations, re_batch_labels)

            attribute_network_optim.zero_grad()
            relation_network_optim.zero_grad()

            loss_rel.backward()

            attribute_network_optim.step()
            relation_network_optim.step()

            total_steps += 1

        # ZSL: unseen images scored against unseen classes only.
        zsl = compute_accuracy_whole_network(attribute_network,
                                             relation_network,
                                             test_unseen_loader,
                                             test_unseen_label_set,
                                             all_embeddings)
        # GZSL: both splits scored against the full label set.
        gzsl_u = compute_accuracy_whole_network(attribute_network,
                                                relation_network,
                                                test_unseen_loader,
                                                test_label_set, all_embeddings)
        gzsl_s = compute_accuracy_whole_network(attribute_network,
                                                relation_network,
                                                test_seen_loader,
                                                test_label_set, all_embeddings)

        # Harmonic mean of seen/unseen accuracy.
        # NOTE(review): divides by zero when both accuracies are 0.
        H = 2 * gzsl_s * gzsl_u / (gzsl_u + gzsl_s)
        print(
            "Epoch: {:>3} zsl: {:.5f} gzsl_u: {:.5f} gzsl_s: {:.5f} H: {:.5f}".
            format(epoch, zsl, gzsl_u, gzsl_s, H))
示例#11
0
    # NOTE(review): this is the body of a training entry point whose `def`
    # line is not visible in this chunk; `args` and `filename` come from that
    # enclosing scope.
    # TensorBoard writer; the run directory embeds `filename` and a timestamp.
    writer = SummaryWriter(
        log_dir="/home/atheist8E/Earth/Civilization/Alexandria/{}_time_{}".
        format(filename,
               datetime.now().strftime("%Y_%m_%d_%H_%M_%S")))
    # Standard ImageNet-style augmentation and normalization for training.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # NOTE(review): the test pipeline also uses RandomResizedCrop, so eval
    # inputs are randomly cropped/scaled — confirm this is intentional
    # (Resize + CenterCrop is the usual evaluation choice).
    test_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_loader, test_loader = ImageNet(args, train_transform, test_transform)
    # Model is pinned to the single GPU selected by args.gpu.
    model = IcarusNet_v1(args, writer).cuda(args.gpu)
    student_criterion = nn.CrossEntropyLoss()
    attention_criterion = nn.MSELoss()
    student_optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.learning_rate,
                                        weight_decay=0.0001,
                                        momentum=0.1)
    # Decay the learning rate 10x every 30 epochs.
    scheduler = StepLR(student_optimizer, step_size=30, gamma=0.1)
    # fit() runs the full train/eval loop; report() returns the writer back.
    model.fit(train_loader, test_loader, student_criterion,
              attention_criterion, student_optimizer, args.epoch, scheduler)
    writer = model.report()
    writer.close()
示例#12
0
def main():
    """Meta-train a Relation Network for one-shot Omniglot classification.

    Builds meta-train/val/test episode loaders, trains an embedding module
    and a relation module jointly with MSE on relation scores, validates
    periodically, and reports final accuracy on the meta-test split.
    Relies on module-level constants (CLASS_IN_EP, LEARNING_RATE,
    TRAIN_EPISODES, PRINT_EVERY, VALIDATE_EVERY, VAL_EPISODES,
    TEST_EPISODES) and `device`.
    """
    print("Load data . . .")
    metatrain_dirs, metaval_dirs, metatest_dirs = get_class_dirs()
    metatrain_dataset = OmniglotOneshotDataset(metatrain_dirs)
    metatrain_loader = DataLoader(metatrain_dataset,
                                  batch_size=CLASS_IN_EP,
                                  shuffle=True)
    metaval_dataset = OmniglotOneshotDataset(metaval_dirs)
    metaval_loader = DataLoader(metaval_dataset,
                                batch_size=CLASS_IN_EP,
                                shuffle=True)
    # BUGFIX: the test dataset was built from `metaval_dirs` and the test
    # loader wrapped `metaval_dataset`, so the final "Testing" pass silently
    # re-evaluated the validation split instead of the held-out test split.
    metatest_dataset = OmniglotOneshotDataset(metatest_dirs)
    metatest_loader = DataLoader(metatest_dataset,
                                 batch_size=CLASS_IN_EP,
                                 shuffle=True)

    print("Build model . . .")
    embed_net = EmbeddingModule()
    rel_net = RelationModule()

    print("Setup training . . .")
    criterion = nn.MSELoss()
    embed_opt = torch.optim.Adam(embed_net.parameters(), lr=LEARNING_RATE)
    rel_opt = torch.optim.Adam(rel_net.parameters(), lr=LEARNING_RATE)
    # Halve each learning rate every 100k episodes.
    embed_scheduler = StepLR(embed_opt, step_size=100000, gamma=0.5)
    rel_scheduler = StepLR(rel_opt, step_size=100000, gamma=0.5)
    embed_net.apply(weights_init)
    rel_net.apply(weights_init)
    embed_net.to(device)
    rel_net.to(device)

    print("Training . . .")
    running_loss = 0.0
    for episode in range(1, TRAIN_EPISODES + 1):
        # setup episode
        ep_data = next(iter(metatrain_loader))
        sample = ep_data['sample'].to(device)  # CLASS_IN_EP x 1 x 28 x 28
        query = ep_data['query'].to(
            device)  # CLASS_IN_EP x QUERY_SIZE x 28 x 28
        query = query.view(-1, 1, 28, 28)  # flat_size x 1 x 28 x 28

        # forward pass: embed supports and queries once, then score all pairs
        sample_features = embed_net(
            sample
        )  # CLASS_IN_EP x FEATURE_DIM x 1 x 1 (avoid redundant computation)
        query_features = embed_net(query)  # flat_size x FEATURE_DIM x 1 x 1
        combined, score_target = combine_pairs(sample_features, query_features)
        score_target = score_target.to(device)
        score_pred = rel_net(combined)
        loss = criterion(score_pred, score_target)

        # backward pass & update
        # BUGFIX: only rel_net's gradients were cleared, so embed_net's
        # gradients accumulated across episodes even though embed_opt.step()
        # applied them every episode.
        embed_net.zero_grad()
        rel_net.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(embed_net.parameters(), 0.5)
        nn.utils.clip_grad_norm_(rel_net.parameters(), 0.5)
        embed_opt.step()
        rel_opt.step()
        # Step the schedulers after the optimizer updates (PyTorch >= 1.1
        # ordering); passing the episode index to step() is deprecated.
        embed_scheduler.step()
        rel_scheduler.step()

        # print progress
        running_loss += loss.item()
        if episode % PRINT_EVERY == 0:
            print('Episode %d, avg loss: %f' %
                  (episode, running_loss / PRINT_EVERY))
            running_loss = 0.0

        # validate model
        if episode % VALIDATE_EVERY == 0:
            val_accuracy = evaluate_accuracy(embed_net, rel_net, VAL_EPISODES,
                                             metaval_loader)
            print('Validation accuracy: %f' % (val_accuracy))

    print("Testing . . .")
    test_accuracy = evaluate_accuracy(embed_net, rel_net, TEST_EPISODES,
                                      metatest_loader)
    print('Test accuracy: %f' % (test_accuracy))
示例#13
0
def main():
    # step 1: init dataset
    print("init dataset")
    
    dataroot = './data'
    dataset = 'CUB1_data'
    image_embedding = 'res101'
    class_embedding = 'original_att_splits'

    attribute_values = np.load("CUB_200_2011/attribute_values.npy")
    certainty_values = np.load("CUB_200_2011/certainty_values.npy")

    matcontent = sio.loadmat(dataroot + "/" + dataset + "/" + image_embedding + ".mat")
    feature = matcontent['features'].T

    label = matcontent['labels'].astype(int).squeeze() - 1

    matcontent = sio.loadmat(dataroot + "/" + dataset + "/" + class_embedding + ".mat")
    # numpy array index starts from 0, matlab starts from 1
    trainval_loc = matcontent['trainval_loc'].squeeze() - 1
    test_seen_loc = matcontent['test_seen_loc'].squeeze() - 1
    test_unseen_loc = matcontent['test_unseen_loc'].squeeze() - 1
  
    attribute = matcontent['att'].T

    x = feature[trainval_loc] # train_features

    train_attributes = attribute_values[trainval_loc] # train_attribute_values
    train_certainty = certainty_values[trainval_loc] # train_certainty_values

    train_label = label[trainval_loc].astype(int)  # train_label

    att = attribute[train_label] # train attributes

    x_test = feature[test_unseen_loc]  # test_feature
    test_label = label[test_unseen_loc].astype(int) # test_label

    x_test_seen = feature[test_seen_loc]  #test_seen_feature
    test_label_seen = label[test_seen_loc].astype(int) # test_seen_label

    test_id = np.unique(test_label)   # test_id
    att_pro = attribute[test_id]      # test_attribute
    
    
    # train set
    train_features = torch.from_numpy(x) # [7057, 2048]

    train_label = torch.from_numpy(train_label).unsqueeze(1) # [7057, 1]

    train_attributes = torch.from_numpy(train_attributes) # train_attribute_values

    train_certainty = torch.from_numpy(train_certainty) # train_certainty_values

    # attributes
    all_attributes = np.array(attribute) # (200, 312)
    
    attributes = torch.from_numpy(all_attributes)
    # test set
    test_features = torch.from_numpy(x_test) # [2967, 2048]

    test_label = torch.from_numpy(test_label).unsqueeze(1) # [2967, 1]

    testclasses_id = np.array(test_id) # (50, )

    test_attributes = torch.from_numpy(att_pro).float() # (50, 312)

    
    test_seen_features = torch.from_numpy(x_test_seen) # [1764, 2048]
    
    test_seen_label = torch.from_numpy(test_label_seen) # [1764]
    
    #train_data = TensorDataset(train_features,train_label)
    train_data = IntegratedDataset(train_features, train_label, train_attributes, train_certainty)    

    # init network
    print("init networks")
    attribute_network = AttributeNetwork(2048,1200,624).cuda()
    relation_network = RelationNetwork(936, 400).cuda()

    mse = nn.MSELoss().cuda()
    ce = nn.CrossEntropyLoss(weight=torch.FloatTensor([0.1, 1.])).cuda()
    nll = torch.nn.NLLLoss(weight=torch.FloatTensor([0.1, 1.])).cuda()
    #mse = nn.BCELoss().cuda()

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

    if MODEL == 1:
        attribute_network_optim = torch.optim.Adam(attribute_network.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
        attribute_network_scheduler = StepLR(attribute_network_optim, step_size=30000, gamma=0.5)

        # if os.path.exists("./models/zsl_cub2_attribute_network_v35.pkl"):
        #     attribute_network.load_state_dict(torch.load("./models/zsl_cub_attribute_network_v35.pkl"))
        #     print("load attribute network success")
        # if os.path.exists("./models/zsl_cub2_relation_network_v35.pkl"):
        #     relation_network.load_state_dict(torch.load("./models/zsl_cub_relation_network_v35.pkl"))
        #     print("load relation network success")

        print("training...")
        #last_accuracy = 0.0
        total_steps = 0
        best_loss_att = 10000
        best_loss_cer = 10000

        for epoch in range(EPOCHS):
            attribute_network.train()
            for i, (batch_features, batch_labels, batch_att, batch_cer) in enumerate(train_loader):

                batch_features, batch_att, batch_cer = batch_features.float().cuda(), batch_att.float().cuda(), batch_cer.float().cuda()
                
                attribute_network_scheduler.step(total_steps)
                
                #relation_network_scheduler.step(episode)

                #sample_labels = set(batch_labels.squeeze().numpy().tolist())
                
                #sample_attributes = torch.Tensor([all_attributes[i] for i in sample_labels])
                
                #class_num = sample_attributes.shape[0]
                pred_embeddings = attribute_network(batch_features)

                cat_batch_att = batch_att.long().view(-1)
                cat_pred_att = torch.cat((pred_embeddings[:,0:312].unsqueeze(2), 1 - pred_embeddings[:,0:312].unsqueeze(2)), 2).view(-1, 2)
                
                #loss_att = ce(cat_pred_att, cat_batch_att)
                loss_att = nll(torch.log(cat_pred_att), cat_batch_att)
                loss_cer = mse(pred_embeddings[:,312:624], batch_cer)
                loss_net1 = loss_att + loss_cer

                attribute_network_optim.zero_grad()
                loss_net1.backward()
                attribute_network_optim.step()

                total_steps += 1

            #if epoch % 50 == 0:
            loss_att_mean, loss_cer_mean, acc_att, acc_cer = evaluate_attribute_network(attribute_network, nll, mse, train_loader)
            print("Epoch: {:>3} loss_att: {:.5f} loss_cer: {:.5f}".format(epoch, loss_att_mean, loss_cer_mean))
            print("Epoch: {:>3} acc_att: {:.5f} acc_cer: {:.5f}".format(epoch, acc_att, acc_cer))

            if best_loss_att > loss_att_mean and best_loss_cer > loss_cer_mean:
                torch.save(attribute_network.state_dict(), 'models/attribute_network.pt')
            if best_loss_att > loss_att_mean:
                best_loss_att = loss_att_mean
            if best_loss_cer > loss_cer_mean:
                best_loss_cer = loss_cer_mean

    else:
        attribute_network.load_state_dict(torch.load("./models/attribute_network.pt"))
        print("load attribute network success")

        loss_att_mean, loss_cer_mean = evaluate_attribute_network(attribute_network, ce, train_loader)
        print("loss_att: {:.5f} loss_cer: {:.5f}".format(loss_att_mean, loss_cer_mean))

        assert 1 == 0
        relation_network_optim = torch.optim.Adam(relation_network.parameters(),lr=LEARNING_RATE)
        relation_network_scheduler = StepLR(relation_network_optim, step_size=30000, gamma=0.5)



        
        batch_features = Variable(batch_features).cuda(GPU).float()  # 32*2048
        sample_features = attribute_network(Variable(sample_attributes).cuda(GPU)) #c*2048

        sample_features_ext = sample_features.unsqueeze(0).repeat(BATCH_SIZE,1,1)
        batch_features_ext = batch_features.unsqueeze(0).repeat(class_num,1,1)
        batch_features_ext = torch.transpose(batch_features_ext,0,1)
        
        #print(sample_features_ext)
        #print(batch_features_ext)
        relation_pairs = torch.cat((sample_features_ext,batch_features_ext),2).view(-1,4096)
        relations = relation_network(relation_pairs).view(-1,class_num)
        #print(relations)

        # re-build batch_labels according to sample_labels
        sample_labels = np.array(sample_labels)
        re_batch_labels = []
        for label in batch_labels.numpy():
            index = np.argwhere(sample_labels==label)
            re_batch_labels.append(index[0][0])
        re_batch_labels = torch.LongTensor(re_batch_labels)

        # loss
        # relations = nn.functional.softmax(relations, 1)
        mse = nn.MSELoss().cuda(GPU)
        one_hot_labels = Variable(torch.zeros(BATCH_SIZE, class_num).scatter_(1, re_batch_labels.view(-1,1), 1)).cuda(GPU)
        loss = mse(relations,one_hot_labels)

        # update
        attribute_network.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        attribute_network_optim.step()
        relation_network_optim.step()

        # if (episode+1)%100 == 0:
        #         print("episode:",episode+1,"loss",loss.data[0])

        if (episode + 1) % 2000 == 0:
            # test
            print("Testing...")

            def compute_accuracy(test_features,test_label,test_id,test_attributes):
                """Return the per-class-averaged accuracy on one test split.

                Every test feature is scored against every class attribute
                embedding by the relation network; the arg-max class index
                (position within ``test_id``) is the prediction.
                """
                loader = DataLoader(TensorDataset(test_features, test_label),
                                    batch_size=32, shuffle=False)
                n_classes = test_attributes.shape[0]
                print("class num:", n_classes)

                preds_all = []
                targets_all = []

                for feats, labels in loader:
                    bsz = labels.shape[0]

                    feats = Variable(feats).cuda(GPU).float()
                    # embed the class attribute vectors into feature space
                    protos = attribute_network(Variable(test_attributes).cuda(GPU).float())

                    # pair every sample with every class prototype
                    protos_ext = protos.unsqueeze(0).repeat(bsz, 1, 1)
                    feats_ext = torch.transpose(
                        feats.unsqueeze(0).repeat(n_classes, 1, 1), 0, 1)

                    pairs = torch.cat((protos_ext, feats_ext), 2).view(-1, 4096)
                    scores = relation_network(pairs).view(-1, n_classes)

                    # remap dataset labels to their positions inside test_id
                    remapped = torch.LongTensor(
                        [np.argwhere(test_id == lab)[0][0]
                         for lab in labels.numpy()])

                    _, batch_preds = torch.max(scores.data, 1)
                    preds_all = np.append(preds_all, batch_preds.cpu().numpy())
                    targets_all = np.append(targets_all, remapped.cpu().numpy())

                # average accuracy over the classes actually present
                preds_all = np.array(preds_all, dtype='int')
                targets_all = np.array(targets_all, dtype='int')
                present = np.unique(targets_all)
                total = 0
                for cls in present:
                    mask = np.nonzero(targets_all == cls)[0]
                    total += accuracy_score(targets_all[mask], preds_all[mask])
                return total / present.shape[0]
            
            zsl_accuracy = compute_accuracy(test_features,test_label,test_id,test_attributes)
            gzsl_unseen_accuracy = compute_accuracy(test_features,test_label,np.arange(200),attributes)
            gzsl_seen_accuracy = compute_accuracy(test_seen_features,test_seen_label,np.arange(200),attributes)
            
            H = 2 * gzsl_seen_accuracy * gzsl_unseen_accuracy / (gzsl_unseen_accuracy + gzsl_seen_accuracy)
            
            print('zsl:', zsl_accuracy)
            print('gzsl: seen=%.4f, unseen=%.4f, h=%.4f' % (gzsl_seen_accuracy, gzsl_unseen_accuracy, H))
            
            if zsl_accuracy > last_accuracy:
            

                # save networks
                torch.save(attribute_network.state_dict(),"./models/zsl_cub_attribute_network_v35.pkl")
                torch.save(relation_network.state_dict(),"./models/zsl_cub_relation_network_v35.pkl")

                print("save networks for episode:",episode)
                
                last_accuracy = zsl_accuracy
示例#14
0
class XNRIDECIns(Instructor):
    """
    Train the decoder given the ground truth relations.
    """
    def __init__(self, model: torch.nn.DataParallel, data: dict,
                 es: np.ndarray, cmd):
        """
        Args:
            model: an auto-encoder
            data: train / val /test set
            es: edge list
            cmd: command line parameters
        """
        super(XNRIDECIns, self).__init__(cmd)
        self.model = model
        # wrap each split as a TensorDataset of (adjacency, states) pairs
        self.data = {
            key: TensorDataset(value[0], value[1])
            for key, value in data.items()
        }
        self.es = torch.LongTensor(es)
        # number of nodes
        self.size = cmd.size
        self.batch_size = cmd.batch
        # optimizer
        self.opt = optim.Adam(self.model.parameters(), lr=cfg.lr)
        # learning rate scheduler, same as in NRI
        self.scheduler = StepLR(self.opt,
                                step_size=cfg.lr_decay,
                                gamma=cfg.gamma)

    def train(self):
        """Run the training loop with validation-based model selection.

        After the last epoch, the best checkpoint (lowest validation loss)
        is reloaded and evaluated on the test set.
        """
        # use the loss as the metric for model selection, default: +\infty
        val_best = np.inf
        # path to save the current best model
        prefix = '/'.join(cfg.log.split('/')[:-1])
        name = '{}/best.pth'.format(prefix)
        for epoch in range(1, 1 + self.cmd.epochs):
            self.model.train()
            # shuffle the data at each epoch
            data = self.load_data(self.data['train'], self.batch_size)
            loss_a = 0.
            N = 0.
            for adj, states in data:
                if cfg.gpu:
                    adj = adj.cuda()
                    states = states.cuda()
                # weight each batch by its relative size so the (possibly
                # smaller) last batch does not skew the epoch average
                scale = len(states) / self.batch_size
                # N: number of samples, equal to the batch size with possible exception for the last batch
                N += scale
                loss_a += scale * self.train_nri(states, adj)
            loss_a /= N
            self.log.info('epoch {:03d} loss {:.3e}'.format(epoch, loss_a))
            losses = self.report('val', [cfg.M])

            val_cur = losses[0]
            if val_cur < val_best:
                # update the current best model when approaching a lower loss
                val_best = val_cur
                # save the underlying module (unwrap DataParallel)
                torch.save(self.model.module.state_dict(), name)

            # learning rate scheduling
            self.scheduler.step()
        if self.cmd.epochs > 0:
            # reload the best checkpoint before the final test evaluation
            self.model.module.load_state_dict(torch.load(name))
        _ = self.test('test', 20)

    def report(self, name: str, Ms: list) -> list:
        """
        Evaluate the mean squared errors.

        Args:
            name: 'train' / 'val' / 'test'
            Ms: [...], each element is a number of steps to predict

        Return:
            mses: [...], mean squared errors over all steps
        """
        mses = []
        for M in Ms:
            mse, ratio = self.evaluate(self.data[name], M)
            mses.append(mse)
            self.log.info('{} M {:02d} mse {:.3e} ratio {:.4f}'.format(
                name, M, mse, ratio))
        return mses

    def train_nri(self, states: Tensor, adj: Tensor) -> Tensor:
        """
        Perform one optimization step on a single batch.

        Args:
            states: [batch, step, node, dim], observed node states
            adj: [batch, E, K], ground truth interacting relations

        Return:
            loss: reconstruction loss
        """
        output = self.model.module.predict_states(
            states,
            one_hot(adj.transpose(0, 1)).float(), cfg.M)
        # negative log-likelihood under a Gaussian with fixed variance 5e-5
        loss = nll_gaussian(output, states[:, 1:], 5e-5)
        self.optimize(self.opt, loss * cfg.scale)
        return loss

    def evaluate(self, test, M: int):
        """
        Evaluate related metrics to monitor the training process.

        Args:
            test: data set to be evaluted
            M: number of steps to predict

        Return: 
            mse: mean square error over all steps
            ratio: relative root mean squared error
        """
        mse, ratio = [], []
        data = self.load_data(test, self.batch_size)
        N = 0.
        with torch.no_grad():
            for adj, states in data:
                if cfg.gpu:
                    adj = adj.cuda()
                    states = states.cuda()
                # evaluate on the trailing training window; the first frame
                # is the condition, the rest are prediction targets
                states_dec = states[:, -cfg.train_steps:, :, :]
                target = states_dec[:, 1:]

                output = self.model.module.predict_states(
                    states_dec,
                    one_hot(adj.transpose(0, 1)).float(), M)
                # scale all metrics to match the batch size
                scale = len(states) / self.batch_size
                N += scale

                mse.append(scale * mse_loss(output, target).data)
                ratio.append(scale * (((output - target)**2).sum(-1).sqrt() /
                                      (target**2).sum(-1).sqrt()).mean())
        mse = sum(mse) / N
        ratio = sum(ratio) / N
        return mse, ratio

    def test(self, name: str, M: int):
        """
        Evaluate related metrics to measure the model performance.
        The biggest difference between this function and evaluate() is that, the mses are evaluated at each step.

        Args:
            name: 'train' / 'val' / 'test'
            M: number of steps to predict

        Return:
            mse_multi: mse at each step
        """
        # NOTE(review): the following bare string literal is a no-op
        # statement, kept only as inline documentation of the locals below.
        """
        mses: mean square error over all steps
        ratio: relative root mean squared error
        mse_multi: mse at each step
        """
        mse_multi, mses, ratio = [], [], []
        data = self.load_data(self.data[name], self.batch_size)
        N = 0.
        with torch.no_grad():
            for adj, states in data:
                if cfg.gpu:
                    adj = adj.cuda()
                    states = states.cuda()
                states_dec = states[:, -cfg.train_steps:, :, :]
                target = states_dec[:, 1:]

                # NOTE(review): this first prediction uses cfg.M (as in
                # training) rather than the M argument — confirm intended
                output = self.model.module.predict_states(
                    states_dec,
                    one_hot(adj.transpose(0, 1)).float(), cfg.M)
                # scale all metrics to match the batch size
                scale = len(states) / self.batch_size
                N += scale

                mses.append(scale * mse_loss(output, target).data)
                ratio.append(scale * (((output - target)**2).sum(-1).sqrt() /
                                      (target**2).sum(-1).sqrt()).mean())

                # second pass: predict M steps beyond the training window and
                # keep the per-step errors (averaged over batch, node, dim)
                states_dec = states[:, cfg.train_steps:cfg.train_steps + M +
                                    1, :, :]
                target = states_dec[:, 1:]

                output = self.model.module.predict_states(
                    states_dec,
                    one_hot(adj.transpose(0, 1)).float(), M)
                mse = ((output - target)**2).mean(dim=(0, 2, -1))
                mse *= scale
                mse_multi.append(mse)
        mses = sum(mses) / N
        mse_multi = sum(mse_multi) / N
        ratio = sum(ratio) / N
        self.log.info('{} M {:02d} mse {:.3e} ratio {:.4f}'.format(
            name,
            M,
            mses,
            ratio,
        ))
        msteps = ','.join(['{:.3e}'.format(i) for i in mse_multi])
        self.log.info(msteps)
        return mse_multi
def main(DATASET,
         LABELS,
         CLASS_IDS,
         BATCH_SIZE,
         ANNOTATION_FILE,
         SEQ_SIZE=16,
         STEP=16,
         fstep=1,
         nstrokes=-1,
         N_EPOCHS=25):
    '''
    Build bag-of-visual-words one-hot features, load a trained Transformer
    model and extract its sequence features for the train/val/test splits.

    Parameters:
    -----------
    DATASET : str
        path to the video dataset
    LABELS : str
        path containing stroke labels
    CLASS_IDS : str
        path to txt file defining classes, similar to THUMOS
    BATCH_SIZE : int
        size for batch of clips
    ANNOTATION_FILE : str
        path to the stroke annotation file used to derive cluster labels
    SEQ_SIZE : int
        no. of frames in a clip (min. 16 for 3D CNN extraction)
    STEP : int
        stride for next example. If SEQ_SIZE=16, STEP=8, use frames (0, 15), (8, 23) ...
    fstep : int
        future step offset used when building clip pairs and output dirs
    nstrokes : int
        partial extraction of features (do not execute for entire dataset)
    N_EPOCHS : int
        no. of epochs the saved model checkpoint was trained for (used to
        locate the checkpoint file)

    Returns:
    --------
    0 on success; extracted features are pickled under log_path.

    '''
    ###########################################################################
    # NOTE(review): log_path, feat_path, feat/snames (+ _val/_test variants),
    # cluster_size, km_filename, nhid, nlayers and device are module-level
    # globals defined outside this chunk — confirm before reuse.

    attn_utils.seed_everything(1234)

    if not os.path.isdir(log_path):
        os.makedirs(log_path)

    # Read the strokes
    # Divide the highlight dataset files into training, validation and test sets
    train_lst, val_lst, test_lst = autoenc_utils.split_dataset_files(DATASET)
    print("No. of training videos : {}".format(len(train_lst)))

    # For each raw feature stream, build (or load) a KMeans codebook and
    # write one-hot BoVW representations for the three splits.
    ft_path, ft_path_val, ft_path_test = [], [], []
    for i, ft_dir in enumerate(feat_path):
        print("Feature : {}".format(ft_dir))
        features, stroke_names_id = attn_utils.read_feats(
            ft_dir, feat[i], snames[i])
        # get matrix of features from dictionary (N, vec_size)
        vecs = []
        for key in sorted(list(features.keys())):
            vecs.append(features[key])
        vecs = np.vstack(vecs)

        # sanitize NaN/Inf values before clustering
        vecs[np.isnan(vecs)] = 0
        vecs[np.isinf(vecs)] = 0

        #fc7 layer output size (4096)
        INP_VEC_SIZE = vecs.shape[-1]
        print("INP_VEC_SIZE = ", INP_VEC_SIZE)

        km_filepath = os.path.join(log_path, km_filename + "_F" + str(i + 1))
        #    # Uncomment only while training.
        if not os.path.isfile(km_filepath + "_C" + str(cluster_size) + ".pkl"):
            km_model = make_codebook(vecs, cluster_size)  #, model_type='gmm')
            ##    # Save to disk, if training is performed
            print("Writing the KMeans models to disk...")
            pickle.dump(
                km_model,
                open(km_filepath + "_C" + str(cluster_size) + ".pkl", "wb"))
        else:
            # Load from disk, for validation and test sets.
            km_model = pickle.load(
                open(km_filepath + "_C" + str(cluster_size) + ".pkl", 'rb'))

        print("Create numpy one hot representation for train features...")
        onehot_feats = create_bovw_onehot(features, stroke_names_id, km_model)

        ft_path.append(
            os.path.join(
                log_path,
                "F" + str(i + 1) + "_C" + str(cluster_size) + "_train.pkl"))
        with open(ft_path[-1], "wb") as fp:
            pickle.dump(onehot_feats, fp)

    ###########################################################################
        # validation split: reuse the codebook fitted/loaded above

        features_val, stroke_names_id_val = attn_utils.read_feats(
            ft_dir, feat_val[i], snames_val[i])

        print("Create numpy one hot representation for val features...")
        onehot_feats_val = create_bovw_onehot(features_val,
                                              stroke_names_id_val, km_model)

        ft_path_val.append(
            os.path.join(
                log_path,
                "F" + str(i + 1) + "_C" + str(cluster_size) + "_val.pkl"))
        with open(ft_path_val[-1], "wb") as fp:
            pickle.dump(onehot_feats_val, fp)

    ###########################################################################
        # test split: reuse the codebook fitted/loaded above

        features_test, stroke_names_id_test = attn_utils.read_feats(
            ft_dir, feat_test[i], snames_test[i])

        print("Create numpy one hot representation for test features...")
        onehot_feats_test = create_bovw_onehot(features_test,
                                               stroke_names_id_test, km_model)

        ft_path_test.append(
            os.path.join(
                log_path,
                "F" + str(i + 1) + "_C" + str(cluster_size) + "_test.pkl"))
        with open(ft_path_test[-1], "wb") as fp:
            pickle.dump(onehot_feats_test, fp)

    ###########################################################################
    # Create a Dataset
    train_dataset = StrokeMultiFeaturePairsDataset(ft_path,
                                                   train_lst,
                                                   DATASET,
                                                   LABELS,
                                                   CLASS_IDS,
                                                   frames_per_clip=SEQ_SIZE,
                                                   extracted_frames_per_clip=2,
                                                   step_between_clips=STEP,
                                                   future_step=fstep,
                                                   train=True)
    val_dataset = StrokeMultiFeaturePairsDataset(ft_path_val,
                                                 val_lst,
                                                 DATASET,
                                                 LABELS,
                                                 CLASS_IDS,
                                                 frames_per_clip=SEQ_SIZE,
                                                 extracted_frames_per_clip=2,
                                                 step_between_clips=STEP,
                                                 future_step=fstep,
                                                 train=False)

    #    # created weighted Sampler for class imbalance
    #    samples_weight = attn_utils.get_sample_weights(train_dataset, labs_keys, labs_values,
    #                                                   train_lst)
    #    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True)
    #                              sampler=sampler, worker_init_fn=np.random.seed(12))

    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False)

    data_loaders = {"train": train_loader, "test": val_loader}

    ###########################################################################
    # get labels
    labs_keys, labs_values = attn_utils.get_cluster_labels(ANNOTATION_FILE)
    num_classes = len(list(set(labs_values)))

    ###########################################################################

    # load model and set loss function
    ntokens = cluster_size * len(feat_path)  # the size of vocabulary
    emsize = 200  # embedding dimension
    nhead = 2  # the number of heads in the multiheadattention models
    dropout = 0.2  # the dropout value
    model = tt.TransformerModel(ntokens, emsize, nhead, nhid, nlayers,
                                dropout).to(device)

    #    model = load_weights(log_path, model, N_EPOCHS,
    #                                    "S"+str(SEQ_SIZE)+"C"+str(cluster_size)+"_SGD")

    # Setup the loss fxn
    criterion = nn.CrossEntropyLoss()
    model = model.to(device)
    #    print("Params to learn:")
    params_to_update = []
    for name, param in model.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t", name)

    # Observe that all parameters are being optimized


#    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Decay LR by a factor of 0.1 every 7 epochs
    scheduler = StepLR(optimizer, step_size=15, gamma=0.1)

    #    lr = 5.0 # learning rate
    #    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    #    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
    ###########################################################################
    # Training the model

    #    start = time.time()
    #
    #    model = train_model(model, data_loaders, criterion, optimizer, scheduler,
    #                        labs_keys, labs_values, num_epochs=N_EPOCHS)
    #
    #    end = time.time()
    #
    ##    # save the best performing model
    #    save_model_checkpoint(log_path, model, N_EPOCHS,
    #                                     "S"+str(SEQ_SIZE)+"C"+str(cluster_size)+"_SGD")
    # Load model checkpoints
    model = load_weights(
        log_path, model, N_EPOCHS,
        "S" + str(SEQ_SIZE) + "C" + str(cluster_size) + "_SGD")

    #    print("Total Execution time for {} epoch : {}".format(N_EPOCHS, (end-start)))

    ###########################################################################

    #    acc = predict(features_val, stroke_names_id_val, model, data_loaders, labs_keys,
    #                  labs_values, SEQ_SIZE, phase='test')

    ###########################################################################

    # Extract attention model features
    if not os.path.isfile(
            os.path.join(log_path + "/s" + str(fstep), "trans_feats.pkl")):
        if not os.path.exists(log_path + "/s" + str(fstep)):
            os.makedirs(log_path + "/s" + str(fstep))
        #    # Extract Grid OF / HOOF features {mth = 2, and vary nbins}
        print("Training extraction ... ")
        feats_dict, stroke_names = extract_trans_feats(model,
                                                       DATASET,
                                                       LABELS,
                                                       CLASS_IDS,
                                                       BATCH_SIZE,
                                                       SEQ_SIZE,
                                                       2,
                                                       partition='train',
                                                       nstrokes=nstrokes,
                                                       base_name=log_path)

        with open(
                os.path.join(log_path + "/s" + str(fstep), "trans_feats.pkl"),
                "wb") as fp:
            pickle.dump(feats_dict, fp)
        with open(
                os.path.join(log_path + "/s" + str(fstep), "trans_snames.pkl"),
                "wb") as fp:
            pickle.dump(stroke_names, fp)

    if not os.path.isfile(
            os.path.join(log_path + "/s" + str(fstep), "trans_feats_val.pkl")):
        print("Validation extraction ....")
        feats_dict_val, stroke_names_val = extract_trans_feats(
            model,
            DATASET,
            LABELS,
            CLASS_IDS,
            BATCH_SIZE,
            SEQ_SIZE,
            2,
            partition='val',
            nstrokes=nstrokes,
            base_name=log_path)

        with open(
                os.path.join(log_path + "/s" + str(fstep),
                             "trans_feats_val.pkl"), "wb") as fp:
            pickle.dump(feats_dict_val, fp)
        with open(
                os.path.join(log_path + "/s" + str(fstep),
                             "trans_snames_val.pkl"), "wb") as fp:
            pickle.dump(stroke_names_val, fp)

    if not os.path.isfile(
            os.path.join(log_path + "/s" + str(fstep),
                         "trans_feats_test.pkl")):
        print("Testing extraction ....")
        feats_dict_val, stroke_names_val = extract_trans_feats(
            model,
            DATASET,
            LABELS,
            CLASS_IDS,
            BATCH_SIZE,
            SEQ_SIZE,
            2,
            partition='test',
            nstrokes=nstrokes,
            base_name=log_path)

        with open(
                os.path.join(log_path + "/s" + str(fstep),
                             "trans_feats_test.pkl"), "wb") as fp:
            pickle.dump(feats_dict_val, fp)
        with open(
                os.path.join(log_path + "/s" + str(fstep),
                             "trans_snames_test.pkl"), "wb") as fp:
            pickle.dump(stroke_names_val, fp)

    # call count_paramters(model)  for displaying total no. of parameters
    print("#Parameters : {} ".format(autoenc_utils.count_parameters(model)))
    return 0
示例#16
0
文件: Main.py 项目: LordVenter/MNIST
        self.l2 = Linear(128, 10)
        self.relu = ReLU()

    def forward(self, x):
        """Forward pass: two conv+pool+dropout blocks, then two FC layers.

        Args:
            x: input batch, reshaped to (-1, 1, 28, 28) MNIST images.

        Returns:
            Tensor of shape (batch, 10); note ReLU is applied to the final
            layer as well, so outputs are non-negative activations.
        """
        # Pass self.training so dropout is disabled in eval mode —
        # F.dropout2d defaults to training=True, which would otherwise keep
        # dropping units at inference time.
        x = dropout2d(self.pool(self.relu(self.conv1(x.view(-1, 1, 28, 28)))),
                      0.5, self.training)
        x = dropout2d(self.pool(self.relu(self.conv2(x))), 0.5, self.training)
        x = self.relu(self.l1(x.view(-1, self.feature_size)))
        x = self.relu(self.l2(x))
        return x


# Instantiate the network directly on the GPU.
net = Net().to('cuda')
# Adam optimizer with a fixed learning rate of 1e-4.
optimizer = optim.Adam(net.parameters(), 0.0001)

# Multiply the learning rate by 0.7 after every epoch.
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)


def train(model, device, train_loader, optimizer, epoch, log_interval=20):
    """Train `model` for one epoch over `train_loader`.

    Targets are one-hot encoded so MSE can be used as the criterion.
    Progress is printed every `log_interval` batches.
    """
    model.train()
    n_total = len(train_loader.dataset)
    n_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        loss = mse_loss(model(data), to_categorical(target, 10))
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            seen = batch_idx * len(data)
            pct = round(100 * batch_idx / n_batches)
            print(
                f'Train Epoch: {epoch+1} [{seen}/{n_total} ({pct}%)]\tLoss: {loss.item()}'
            )
示例#17
0
class Trainer:
    """Fine-tunes a model's classifier head on an ImageFolder dataset.

    Splits one folder into train/validation subsets, trains with SGD +
    StepLR, periodically evaluates on a validation batch, and checkpoints
    the whole model after every epoch.
    """

    def __init__(self, args, model):
        """
        Args:
            args: command line parameters (model, epochs, train_path, verbose)
            model: network whose `classifier` sub-module will be optimized
        """
        self.verbose = args.verbose
        self.model = model
        self.path = f'models/{args.model}.pth'
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.verbose:
            logger.debug('model {} loaded'.format(args.model))

        self.num_epochs = args.epochs
        self.train_loader, self.test_loader = self.load_dataset(args.train_path)
        self.criterion = torch.nn.CrossEntropyLoss()
        learning_rate = 0.1
        # only the classifier head is optimized; backbone weights are frozen
        # by exclusion from the optimizer (their requires_grad is untouched)
        self.optimizer = torch.optim.SGD(model.classifier.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
        #self.optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.0001)

        self.scheduler = StepLR(self.optimizer, step_size=10, gamma=0.1)

        logger.info(self.optimizer)
        logger.info(self.criterion)
        self.model.cuda()

    def load_dataset(self, path, batch_size=64, num_workers=16, pin_memory=True, valid_size=.2):
        """Split `path` (ImageFolder layout) into train/validation loaders.

        Args:
            path: dataset root directory
            batch_size: samples per batch
            num_workers: DataLoader worker processes
            pin_memory: pin host memory for faster GPU transfer
            valid_size: fraction of samples held out for validation

        Returns:
            (train_loader, test_loader) drawing disjoint index subsets.
        """
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        # NOTE(review): transforms.Scale / RandomSizedCrop are deprecated
        # aliases (removed in newer torchvision) — migrate to Resize /
        # RandomResizedCrop when upgrading.
        transform = transforms.Compose([
            transforms.Scale(256),
            transforms.RandomSizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        train_data = datasets.ImageFolder(path, transform)
        test_data = datasets.ImageFolder(path, transform)
        num_train = len(train_data)
        indices = list(range(num_train))
        split = int(np.floor(valid_size * num_train))
        # fixed seed so the train/validation split is reproducible
        np.random.seed(1234)

        np.random.shuffle(indices)
        train_idx, test_idx = indices[split:], indices[:split]
        train_sampler = SubsetRandomSampler(train_idx)
        test_sampler = SubsetRandomSampler(test_idx)
        train_loader = torch.utils.data.DataLoader(train_data,
                                                   sampler=train_sampler,
                                                   batch_size=batch_size,
                                                   num_workers=num_workers,
                                                   pin_memory=pin_memory
                                                   )
        test_loader = torch.utils.data.DataLoader(test_data,
                                                  sampler=test_sampler,
                                                  batch_size=batch_size,
                                                  num_workers=num_workers,
                                                  pin_memory=pin_memory
                                                  )
        return train_loader, test_loader

    def train(self):
        """Run the full training loop; checkpoint after each epoch.

        Validation is a single batch drawn every `print_every` steps.
        Loss/accuracy histories are stored on self for `view()`.
        """
        print_every = 10
        train_losses, test_losses = [], []
        test_accuracy = []
        logger.info('start train')
        total_steps = len(self.train_loader)
        logger.info(f'test_loader batches: {len(self.test_loader)}')
        logger.info(f'train_loader batches: {len(self.train_loader)}')

        for epoch in range(self.num_epochs):
            running_loss = 0
            test_loss = 0
            test_acc = 0
            steps = 0
            # NOTE(review): stepping the scheduler before the optimizer is the
            # pre-PyTorch-1.1 ordering; newer versions expect it after the
            # epoch's optimizer steps — confirm the intended schedule.
            self.scheduler.step()
            logger.info(f'Epoch: {epoch} LR: {self.scheduler.get_lr()}')

            for inputs, labels in self.train_loader:
                steps += 1
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()

                # forward + backward + optimize
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()
                logger.info(f'step: {steps} / {total_steps} loss: {loss}')

                # validate
                if steps % print_every == 0:
                    print('steps', steps)
                    # debug: gradient magnitudes of the optimized parameters
                    for p in self.model.parameters():
                        if p.grad is not None:
                            print(torch.sum(p.grad.data))

                    with torch.no_grad():
                        iterator = iter(self.test_loader)
                        # BUGFIX: iterator.next() is Python 2 only and raises
                        # AttributeError on Python 3; use the next() builtin.
                        inputs, labels = next(iterator)
                        inputs, labels = inputs.to(self.device), labels.to(self.device)
                        outputs = self.model(inputs)
                        batch_loss = self.criterion(outputs, labels)
                        test_loss += batch_loss.item()
                        # treat raw outputs as log-probabilities for top-1
                        ps = torch.exp(outputs)
                        top_p, top_class = ps.topk(1, dim=1)
                        equals = top_class == labels.view(*top_class.shape)
                        accuracy = torch.mean(equals.type(torch.FloatTensor)).item()
                        test_acc += accuracy
                        logger.info(f'val_accuracy: {accuracy}  val_loss: {batch_loss}')

            train_losses.append(running_loss / len(self.train_loader))
            test_losses.append(test_loss / len(self.train_loader) * print_every)
            test_accuracy.append(test_acc / len(self.train_loader) * print_every)

            logger.info(f"Epoch {epoch + 1}/{self.num_epochs}.. "
                        f"Train loss: {running_loss / total_steps}.. ")
            torch.save(self.model, self.path.replace('.pth', f'_epoch_{epoch}.pth'))

        self.train_losses = train_losses
        self.test_losses = test_losses
        self.accuracy = test_accuracy

        torch.save(self.model, self.path)

    def view(self):
        """Plot and save the recorded loss/accuracy curves to training.png."""
        plt.plot(self.train_losses, label='Training loss')
        plt.plot(self.test_losses, label='Validation loss')
        plt.plot(self.accuracy, label='Validation accuracy')

        plt.legend(frameon=False)
        plt.savefig('training.png')
        plt.show()
示例#18
0
def train():
    """Train the BraTS FLAIR segmentation+classification model with deep supervision.

    Runs `epochs` passes over train and (one batch of) validation data,
    combining cross-entropy and per-class dice losses from the main head
    and two deep-supervision heads (b3 at 1/4 and b4 at 1/2 resolution),
    phased in on a schedule, and keeps the 5 best validation checkpoints
    on disk.

    NOTE(review): this function reads many module-level globals (model,
    csv_dir, hgg_dir, batch_size, alpha/beta/thetha, writer, L/Lc/Lv/Ld/...,
    hotlist, biglist, save_initial, ...) — it is not usable in isolation.
    """

    # forward_times drives gradient accumulation; counter_t indexes the
    # TensorBoard scalars; vis_count is only used by the commented-out vis().
    forward_times = counter_t = vis_count = 0

    dataset = BraTS_FLAIR(csv_dir,
                          hgg_dir,
                          transform=None,
                          train_size=train_size)
    dataset_val = BraTS_FLAIR_val(csv_dir, hgg_dir)  #val
    data_loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(dataset_val,
                            batch_size_val,
                            shuffle=True,
                            num_workers=2)
    loaders = {'train': data_loader, 'val': val_loader}

    # Five best checkpoints kept on disk, seeded with a dummy loss of 10000
    # so the first real validation losses always qualify.
    top_models = [(0, 10000)] * 5  # (epoch,loss)
    worst_val = 10000  # init val save loop

    loss_fn = torch.nn.CrossEntropyLoss(weight=loss_weights)
    soft_max = torch.nn.Softmax(dim=1)
    bce_classify = torch.nn.BCELoss()
    sigmoid = torch.nn.Sigmoid()

    opt = torch.optim.Adam(model.parameters(), lr=init_lr)
    scheduler = StepLR(opt, step_size=opt_step_size, gamma=opt_gamma)  #**
    # multisteplr 0.5,

    opt.zero_grad()
    counter = 0
    for epoch in range(epochs):
        # counter mirrors epoch+1 and gates the loss curriculum below.
        counter += 1
        for e in loaders:
            if e == 'train':
                counter_t += 1
                model.train()
                grad = True  #
            else:
                model.eval()
                grad = False

            with torch.set_grad_enabled(grad):
                for idx, batch_data in enumerate(loaders[e]):
                    torch.cuda.empty_cache()
                    batch_input = Variable(
                        batch_data['img'].float()).cuda()  #.to(device)
                    batch_gt_mask = Variable(
                        batch_data['mask'].float()).cuda()  #.to(device)
                    batch_seg_orig = Variable(
                        batch_data['seg_orig']).cuda()  #.to(device)
                    eg_name = batch_data['eg_name']
                    classify_tars = batch_data['classify'].cuda()

                    # Nearest-neighbour-downsampled ground truth for the
                    # deep-supervision head b3 (1/4 resolution).
                    sup_block3 = resize(
                        np.array(batch_data['seg_orig']),
                        [batch_size, 172 / 4, 172 / 4, 128 / 4],
                        order=0,
                        mode='constant',
                        preserve_range=True)
                    sup_block3 = torch.tensor(sup_block3).cuda()
                    one_hot_b3 = one_hot(sup_block3).transpose(
                        0, 1
                    )  # transpose cause one_hot made for (h,w,d) not (bs,h..)
                    #                     print('b3',sup_block3.size(),one_hot_b3.size())

                    # Same for head b4 (1/2 resolution).
                    sup_block4 = resize(
                        np.array(batch_data['seg_orig']),
                        [batch_size, 172 / 2, 172 / 2, 128 / 2],
                        order=0,
                        mode='constant',
                        preserve_range=True)
                    sup_block4 = torch.tensor(sup_block4).cuda()
                    one_hot_b4 = one_hot(sup_block4).transpose(0, 1)
                    #                     print('b4',sup_block4.size(),one_hot_b4.size())

                    classify, b3, b4, pred_mask = model(batch_input)
                    if e == 'train': forward_times += 1

                    # Main head: cross-entropy plus mean of the per-class
                    # dice losses (a..d are per-class dice-loss scalars).
                    ce = loss_fn(pred_mask, batch_seg_orig.long())
                    soft_mask = soft_max(pred_mask)
                    dice = dice_loss(soft_mask, batch_gt_mask)

                    a, b, c, d = dice_loss_classes(soft_mask, batch_gt_mask)

                    lossnet = ce + (a + b + c + d) / 4  # +dice

                    #                     print('lb3',b3.size(),sup_block3.size())
                    ceb3 = loss_fn(b3, sup_block3.long())
                    soft_mask_b3 = soft_max(b3)
                    a1, b1, c1, d1 = dice_loss_classes(soft_mask_b3,
                                                       one_hot_b3.float())
                    lossb3 = ceb3 + (a1 + b1 + c1 + d1) / 4

                    ceb4 = loss_fn(b4, sup_block4.long())
                    soft_mask_b4 = soft_max(b4)
                    a2, b2, c2, d2 = dice_loss_classes(soft_mask_b4,
                                                       one_hot_b4.float())
                    lossb4 = ceb4 + (a2 + b2 + c2 + d2) / 4

                    # Curriculum: start from the deepest supervision only,
                    # then phase in b3 (after epoch 30), the main head
                    # (after 80) and the classification branch (after 120).
                    loss = beta * lossb4
                    if counter > 30: loss += alpha * lossb3
                    if counter > 80: loss += lossnet
                    if counter > 120:
                        loss += thetha * bce_classify(sigmoid(classify),
                                                      classify_tars.float())

                    # Late in training, triple the loss on hard examples
                    # (any per-class dice loss still above 0.5).
                    if epoch > epoch_magnifiy:
                        if b > 0.5 or c > 0.5 or d > 0.5:
                            loss = 3 * loss
                            print('Magnified')

                    # if epoch>epoch_trouble:
                    #     if loss>1.5*loss_moving_avg:
                    #         loss=3*loss
                    #         print('Trouble')

                    print('sums', lossnet.item(), loss.item())

                    print('Dice Losses: ', a.item(), b.item(), c.item(),
                          d.item())
                    if e == 'train':
                        Lc.append(ce.item())
                        cross_moving_avg = sum(Lc) / len(Lc)
                        #                         Ldc.append(np.array([a.item(),b.item(),c.item(),d.item()]))
                        #                         Ld.append(dice.item()); dice_moving_avg=sum(Ld)/len(Ld)
                        L.append(loss.item())
                        loss_moving_avg = sum(L) / len(L)
                        loss.backward()
                        print('Epoch: ', epoch + 1, ' Batch: ', idx + 1,
                              ' lr: ',
                              scheduler.get_lr()[-1], ' CE: ',
                              cross_moving_avg, ' Loss:', loss_moving_avg)
                        # Gradient accumulation: optimizer only steps once
                        # every grad_accu_times forward passes.
                        if forward_times == grad_accu_times:
                            opt.step()
                            opt.zero_grad()
                            forward_times = 0
                            print('\nUpdate weights ... \n')

                        writer.add_scalar('Total Train Loss', loss.item(),
                                          counter_t)
                        writer.add_scalar('Target Train Loss', lossnet.item(),
                                          counter_t)
                        writer.add_scalar('Train CE', ce.item(), counter_t)
                        #                         writer.add_scalar('Train Dice', dice.item() , counter_t)
                        #                         writer.add_scalar('D1', a.item(), counter_t)
                        writer.add_scalar('D2', b.item(), counter_t)
                        writer.add_scalar('D3', c.item(), counter_t)

                        writer.add_scalar('D4', d.item(), counter_t)
                        writer.add_scalar('lossb3', lossb3.item(), counter_t)
                        writer.add_scalar('lossb4', lossb4.item(), counter_t)
                        # NOTE(review): `loss_classify` is never assigned in
                        # this function, so the try always raises NameError
                        # and 0 is logged; the classification loss is folded
                        # directly into `loss` above instead.
                        try:
                            writer.add_scalar('loss_classify',
                                              loss_classify.item(), counter_t)
                        except:
                            writer.add_scalar('loss_classify', 0, counter_t)
                        writer.add_scalar('Lr',
                                          scheduler.get_lr()[-1], counter_t)
                        torch.cuda.empty_cache()
                        if epoch > log_epoch:

                            if b > 0.5 or c > 0.5 or d > 0.5 or eg_name[
                                    0] in hotlist.keys(
                                    ):  #hotlist,start after N epochs,keep tracking once added
                                print('hotlist culprit')
                                if eg_name[0] not in hotlist.keys():
                                    hotlist[eg_name[0]] = [[ce.item()],
                                                           [b.item()],
                                                           [c.item()],
                                                           [d.item()]]
                                else:
                                    hotlist[eg_name[0]][0].append(ce.item())
                                    hotlist[eg_name[0]][1].append(b.item())
                                    hotlist[eg_name[0]][2].append(c.item())
                                    hotlist[eg_name[0]][3].append(d.item())
                        if epoch > log_epoch:
                            if epoch % 20 == 0:
                                print('biglist updated')

                                if eg_name[0] not in biglist.keys():
                                    biglist[eg_name[0]] = [[ce.item()],
                                                           [b.item()],
                                                           [c.item()],
                                                           [d.item()]]
                                else:
                                    biglist[eg_name[0]][0].append(ce.item())
                                    biglist[eg_name[0]][1].append(b.item())
                                    biglist[eg_name[0]][2].append(c.item())
                                    biglist[eg_name[0]][3].append(d.item())

                        # vis(soft_mask,batch_seg_orig,vis_count,mode='train')
                        # vis_count+=25 # += no of images in vis loop
                        # NOTE(review): stepped once per *training batch*, so
                        # StepLR decays after opt_step_size batches, not
                        # epochs — confirm this is intended.
                        scheduler.step()

                    else:
                        Lv.append(loss.item())
                        Lvd.append(dice.item())
                        Lvc.append(ce.item())
                        lv_avg = sum(Lv) / len(Lv)
                        lvd_avg = sum(Lvd) / len(Lvd)
                        lvc_avg = sum(Lvc) / len(Lvc)
                        #                         Lvdc.append(np.array([a.item(),b.item(),c.item(),d.item()]))
                        #                         Ldc.append(np.array([a.item(),b.item(),c.item(),d.item()]))

                        writer.add_scalar('Val Loss', loss.item(), counter_t)
                        writer.add_scalar('Val CE', ce.item(), counter_t)
                        writer.add_scalar('Val Dice', dice.item(), counter_t)
                        writer.add_scalar('D1v', a.item(), counter_t)
                        writer.add_scalar('D2v', b.item(), counter_t)
                        writer.add_scalar('D3v', c.item(), counter_t)
                        writer.add_scalar('D4', d.item(), counter_t)
                        # vis(soft_mask,batch_seg_orig,vis_count,mode='val')
                        print(save_initial)

                        print('Validation total Loss::::::::::',
                              round(loss.item(), 3))
                        del batch_input, batch_gt_mask, batch_seg_orig, pred_mask

                        print('current n worst val: ', round(loss.item(), 2),
                              worst_val)
                        torch.cuda.empty_cache()
                        if epoch > 88 and epoch % 15 == 0:  # save every 15- for logs- confilicting with down

                            checkpoint = {
                                'epoch': epoch + 1,
                                'moving loss': L,
                                'dice': Ld,
                                'val': Lv,
                                'hotlist': hotlist,
                                'biglist': biglist,
                                'valc': Lvc,
                                'vald': Lvd,
                                'cross el': Lc,
                                'state_dict': model.state_dict(),
                                'optimizer': opt.state_dict()
                            }
                            # NOTE(review): `loss` is still a tensor here, so
                            # round(loss, 2) puts "tensor(...)" in the file
                            # name — loss.item() was probably intended.
                            torch.save(
                                checkpoint,
                                '/home/Drive3/rahul/' + save_initial + '-' +
                                str(round(loss, 2)) + '|' + str(epoch + 1))
                            print(
                                'Saved at 25 : ', save_initial + '-' +
                                str(round(loss, 2)) + '|' + str(epoch + 1))

                        if loss < worst_val:
                            # print('saving --------------------------------------',epoch)
                            top_models = sorted(
                                top_models,
                                key=lambda x: x[1])  # sort maybe not needed

                            checkpoint = {
                                'epoch': epoch + 1,
                                'moving loss': L,
                                'dice': Ld,
                                'val': Lv,
                                'hotlist': hotlist,
                                'biglist': biglist,
                                'valc': Lvc,
                                'vald': Lvd,
                                'cross el': Lc,
                                'state_dict': model.state_dict(),
                                'optimizer': opt.state_dict()
                            }
                            torch.save(
                                checkpoint,
                                '/home/Drive3/rahul/' + save_initial + '-' +
                                str(round(loss, 2)) + '|' + str(epoch + 1))

                            # Evict the current worst of the 5 kept models;
                            # file names encode "<initial>-<loss>|<epoch>".
                            to_be_deleted = '/home/Drive3/rahul/' + save_initial + '-' + str(
                                round(top_models[-1][1], 2)) + '|' + str(
                                    top_models[-1][0])  # ...loss|epoch
                            # print(to_be_deleted)

                            top_models.append((epoch + 1, loss.item()))

                            top_models = sorted(top_models, key=lambda x: x[
                                1])  #sort after addition of new val
                            top_models.pop(-1)

                            print('top_models', top_models)

                            worst_val = top_models[-1][1]
                            if str(
                                    to_be_deleted
                            ) != '/home/Drive3/rahul/' + save_initial + '-' + '10000.0|0':  # first 5 epoch will be saved and no deletion this point
                                os.remove(to_be_deleted)
                                # print('sucess deleted------------------')

                        # Only the first validation batch is evaluated each
                        # epoch (val_loader is shuffled, so it is a sample).
                        break
# ---- Example #19 (scraped-snippet separator; original artifact text: "示例#19" / "0") ----
    def train_pixel(self, pixelarr):
        """Train one per-pixel-chunk neural network; designed to run in parallel.

        A separate Net (or restarted model) is trained to predict the flux
        at the wavelength pixels in `pixelarr` from stellar labels, with
        per-epoch redrawing of the training set and optional adaptive
        resampling of the worst-fit validation labels.

        Args:
            pixelarr: index array of wavelength pixels handled by this worker.

        Returns:
            [pixelarr, trained model, optimizer, elapsed wall time].
        """

        # start a timer
        starttime = datetime.now()

        startpix = pixelarr[0]
        stoppix = pixelarr[-1]
        wavestart = self.wavelength[pixelarr][0]
        wavestop = self.wavelength[pixelarr][-1]

        print('Pixels: {0}-{1}, Wave: {2}-{3}, pulling first spectra'.format(
            startpix, stoppix, wavestart, wavestop))
        pullspectra_i = pullspectra(MISTpath=self.MISTpath,
                                    C3Kpath=self.C3Kpath)

        # change labels into old_labels
        old_labels_o = self.labels_o

        # create tensor for labels
        X_train_Tensor = Variable(torch.from_numpy(old_labels_o).type(dtype))

        # pull fluxes at wavelength pixel
        Y_train = np.array(self.spectra[:, pixelarr])
        Y_train_Tensor = Variable(torch.from_numpy(Y_train).type(dtype),
                                  requires_grad=False)

        # determine if user wants to start from old file, or
        # create a new ANN model
        if self.restartfile != False:
            # create a model
            model = readNN(self.restartfile, wavestart, wavestop)

        else:
            # determine the acutal D_out for this batch of pixels
            D_out = len(pixelarr)

            # initialize the model
            model = Net(self.D_in, self.H, D_out)

        # set min and max pars to grid bounds for encoding
        # (order presumably log10(Teff), log g, [Fe/H], [a/Fe] — matches the
        # log header written below; confirm against Net's encoding.)
        model.xmin = np.array([np.log10(2500.0), -1.0, -4.0, -0.2])
        model.xmax = np.array([np.log10(15000.0), 5.5, 0.5, 0.6])

        # initialize the loss function
        loss_fn = torch.nn.MSELoss(reduction='sum')
        # loss_fn = torch.nn.SmoothL1Loss(size_average=False)
        # loss_fn = torch.nn.KLDivLoss(size_average=False)

        # initialize the optimizer
        learning_rate = 0.01
        # optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        optimizer = torch.optim.Adamax(model.parameters(), lr=learning_rate)

        # initialize the scheduler to adjust the learning rate
        # (decay lr by 0.75 every 3 epochs)
        scheduler = StepLR(optimizer, 3, gamma=0.75)
        # scheduler = ReduceLROnPlateau(optimizer,mode='min',factor=0.1)

        print('Pixels: {0}-{1}, Wave: {2}-{3}, Start Training...'.format(
            startpix, stoppix, wavestart, wavestop))

        for epoch_i in range(self.epochs):
            # adjust the optimizer lr
            # NOTE(review): scheduler.step() before the optimizer steps is
            # the pre-PyTorch-1.1 calling order; newer versions warn here.
            scheduler.step()
            lr_i = optimizer.param_groups[0]['lr']

            epochtime = datetime.now()

            # early-stopping bookkeeping for the inner iteration loop
            bestloss = np.inf
            stopcounter = 0

            for t in range(self.niter):
                steptime = datetime.now()

                def closure():
                    # Before the backward pass, use the optimizer object to zero all of the
                    # gradients for the variables it will update (which are the learnable weights
                    # of the model)
                    optimizer.zero_grad()

                    # Forward pass: compute predicted y by passing x to the model.
                    y_pred_train_Tensor = model(X_train_Tensor)

                    # Compute and print loss.
                    loss = loss_fn(y_pred_train_Tensor, Y_train_Tensor)

                    # Backward pass: compute gradient of the loss with respect to model parameters
                    loss.backward()

                    if np.isnan(loss.item()):
                        print(y_pred_train_Tensor)
                        print(Y_train_Tensor)

                    if (t + 1) % 5000 == 0:
                        print(
                            '--> WL: {0:6.2f}-{1:6.2f} -- Pix: {2}-{3} -- Ep: {4} -- St [{5:d}/{6:d}] -- Time/step: {7} -- Train Loss: {8:.7f}'
                            .format(wavestart, wavestop, startpix, stoppix,
                                    epoch_i + 1, t + 1, self.niter,
                                    datetime.now() - steptime, loss.item()))
                        sys.stdout.flush()

                    return loss

                # Calling the step function on an Optimizer makes an update to its parameters
                loss_i = optimizer.step(closure)

                if np.isnan(loss_i.item()):
                    print(loss_i)
                    print('LOSS ARE NANS')
                    raise ValueError

                # first allow it to train 10K steps
                # NOTE(review): comment says 10K but the guard is t > 20000.
                if t > 20000:
                    # check to see if it hits our fit tolerance limit
                    # (1000 consecutive near-identical losses => converged)
                    if np.abs(loss_i.item() - bestloss) < 1e-4:
                        stopcounter += 1
                        if stopcounter >= 1000:
                            break
                    else:
                        stopcounter = 0
                    bestloss = loss_i.item()

            # # re-draw spectra for next epoch
            # spectra_o,labels_o,wavelength = pullspectra_i(
            # 	self.numtrain,resolution=self.resolution, waverange=self.waverange,
            # 	MISTweighting=True,excludelabels=old_labels_o)
            # spectra = spectra_o.T

            # draw a fresh validation set, excluding labels already trained on
            spectra_o, labels_o, wavelength = pullspectra_i.pullpixel(
                pixelarr,
                num=self.numtrain,
                resolution=self.resolution,
                waverange=self.waverange,
                MISTweighting=True,
                excludelabels=old_labels_o,
                Teff=self.Teffrange,
                logg=self.loggrange,
                FeH=self.FeHrange,
                aFe=self.aFerange)

            # create X tensor
            X_valid = labels_o
            X_valid_Tensor = Variable(torch.from_numpy(labels_o).type(dtype))

            # pull fluxes at wavelength pixel and create tensor
            # Y_valid = np.array(spectra[pixel_no,:]).T
            Y_valid = spectra_o
            Y_valid_Tensor = Variable(torch.from_numpy(Y_valid).type(dtype),
                                      requires_grad=False)

            # Validation Forward pass: compute predicted y by passing x to the model.
            Y_pred_valid_Tensor = model(X_valid_Tensor)
            Y_pred_valid = Y_pred_valid_Tensor.data.numpy()

            # calculate the residual at each validation label
            valid_residual = np.squeeze(Y_valid.T - Y_pred_valid.T)
            if valid_residual.ndim == 1:
                valid_residual = valid_residual.reshape(1, self.numtrain)

            # check to make sure valid_residual isn't all nan's, if so
            if np.isnan(valid_residual).all():
                print('Found an all NaN validation Tensor')
                print('X_valid: ', np.isnan(X_valid).any())
                print('Y_valid: ', np.isnan(Y_valid).any())
                print('Y_pred_valid: ', np.isnan(Y_pred_valid).any())
                raise ValueError

            # create log of the validation step if user wants
            if self.logepoch:
                with open(
                        self.logdir +
                        '/ValidLog_pix{0}_{1}_wave{2}_{3}_epoch{4}.log'.format(
                            startpix, stoppix, wavestart, wavestop,
                            epoch_i + 1), 'w') as logfile:
                    logfile.write('modnum Teff log(g) [Fe/H] [a/Fe] ')
                    for ww in self.wavelength[pixelarr]:
                        logfile.write('resid_{} '.format(ww))
                    logfile.write('\n')
                    for ii, res in enumerate(valid_residual[0]):
                        logfile.write('{0} '.format(ii + 1))
                        logfile.write(
                            np.array2string(X_valid[ii],
                                            separator=' ',
                                            max_line_width=np.inf).replace(
                                                '[', '').replace(']', ''))
                        logfile.write(' ')
                        logfile.write(
                            np.array2string(valid_residual.T[ii],
                                            separator=' ',
                                            max_line_width=np.inf).replace(
                                                '[', '').replace(']', ''))
                        # logfile.write(' {0}'.format(valid_residual.T[ii]))
                        logfile.write('\n')

                # fig = self.plt.figure()
                # ax = fig.add_subplot(111)
                # # residsize = ((10 *
                # # 	(max(np.abs(valid_residual))-np.abs(valid_residual))/
                # # 	(max(np.abs(valid_residual))-min(np.abs(valid_residual)))
                # # 	)**2.0) + 2.0
                # for ii,res in enumerate(valid_residual[0]):
                # 	residsize = ((150 * np.abs(valid_residual.T[ii]))**2.0) + 2.0
                # 	scsym = ax.scatter(10.0**X_valid.T[0],X_valid.T[1],s=residsize,alpha=0.5)
                # lgnd = ax.legend([scsym,scsym,scsym],
                # 	# ['{0:5.3f}'.format(min(np.abs(valid_residual))),
                # 	#  '{0:5.3f}'.format(np.median(np.abs(valid_residual))),
                # 	#  '{0:5.3f}'.format(max(np.abs(valid_residual)))],
                # 	['0.0','0.5','1.0'],
                # 	 loc='upper left',
                # 	)
                # lgnd.legendHandles[0]._sizes = [2]
                # lgnd.legendHandles[1]._sizes = [202]
                # lgnd.legendHandles[2]._sizes = [402]
                # # ax.invert_yaxis()
                # # ax.invert_xaxis()
                # ax.set_xlim(16000,3000)
                # ax.set_ylim(6,-1.5)
                # ax.set_xlabel('Teff')
                # ax.set_ylabel('log(g)')
                # fig.savefig(
                # 	self.pdfdir+'/ValidLog_pix{0}_{1}_wave{2}_{3}_epoch{4}.png'.format(
                # 		startpix,stoppix,wavestart,wavestop,epoch_i+1),fmt='PNG',dpi=128)
                # self.plt.close(fig)

            # check if user wants to do adaptive training
            if self.adaptivetrain:
                # sort validation labels on abs(resid)
                ind = np.argsort(np.amax(np.abs(valid_residual), axis=0))

                # determine worst 1% of validation set
                numbadmod = int(0.01 * self.numtrain)
                # if number of bad models < 5, then by default set it to 5
                if numbadmod < 5:
                    numbadmod = 5
                ind_s = ind[-numbadmod:]
                labels_a = labels_o[ind_s]

                # determine the number of models to add per new point
                numselmod = int(0.1 * self.numtrain)
                # if number of new models < 5, then by default set it to 5
                if numselmod < 5:
                    numselmod = 5

                # NOTE(review): true division — numaddmod is a float and is
                # later passed through int() when drawing new labels.
                numaddmod = numselmod / numbadmod
                if numaddmod <= 1:
                    numaddmod = 2

                # make floor of numaddmod == 1
                # if numaddmod == 0:
                # 	numaddmod = 1

                # cycle through worst samples, adding 10% new models to training set
                for label_i in labels_a:
                    # jitter each bad label until a previously-unseen label
                    # set is found (give up after 100 tries)
                    nosel = 1
                    epsilon = 0.1
                    newlabelbool = False
                    while True:
                        newlabels = np.array([
                            x + epsilon * np.random.randn(int(numaddmod))
                            for x in label_i
                        ]).T
                        labels_check = pullspectra_i.checklabels(newlabels)
                        # check to make sure labels_ai are unique
                        if all([
                                x_ai not in labels_o.tolist()
                                for x_ai in labels_check.tolist()
                        ]):
                            # print('Pixel: {0}, nosel = {1}'.format(pixel_no+1,nosel))
                            newlabelbool = True
                            break
                        # elif (nosel % 100 == 0):
                        # 	print('Pixel: {0}, increasing epsilon to {1} at nosel={2}'.format(pixel_no+1,epsilon*3.0,nosel))
                        # 	epsilon = epsilon*3.0
                        # 	nosel += 1
                        elif (nosel == 100):
                            # print('Pixel: {0}, could not find new model at nosel={1}, quitting'.format(pixel_no+1,nosel))
                            # print(newlabels)
                            break
                        else:
                            nosel += 1
                    """
					if newlabelbool:
						spectra_ai,labels_ai,wavelength = pullspectra_i.selspectra(
							newlabels,
							resolution=self.resolution, 
							waverange=self.waverange,
							)
						Y_valid_a = np.array(spectra_ai.T[pixel_no,:]).T
					"""
                    if newlabelbool:
                        spectra_ai, labels_ai, wavelength = pullspectra_i.pullpixel(
                            pixelarr,
                            inlabels=newlabels,
                            resolution=self.resolution,
                            waverange=self.waverange,
                        )
                        Y_valid_a = spectra_ai
                        Y_valid = np.vstack([Y_valid, Y_valid_a])
                        labels_o = np.append(labels_o, labels_ai, axis=0)
                # rebuild tensors with the augmented label/flux sets
                X_valid_Tensor = Variable(
                    torch.from_numpy(labels_o).type(dtype))
                Y_valid_Tensor = Variable(
                    torch.from_numpy(Y_valid).type(dtype), requires_grad=False)

            # re-use validation set as new training set for the next epoch
            old_labels_o = labels_o
            X_train_Tensor = X_valid_Tensor
            Y_train_Tensor = Y_valid_Tensor

            print(
                'Eph [{4:d}/{5:d}] -- WL: {0:.5f}-{1:.5f} -- Pix: {2}-{3} -- Step Time: {6}, LR: {7:.5f}, Valid max(|Res|): {8:.5f}'
                .format(wavestart, wavestop, startpix, stoppix, epoch_i + 1,
                        self.epochs,
                        datetime.now() - epochtime, lr_i,
                        np.nanmax(np.abs(valid_residual))))
            sys.stdout.flush()

        print('Trained pixel: {0}-{1}/{2} (wavelength: {3}-{4}), took: {5}'.
              format(startpix, stoppix, len(self.spectra[0, :]), wavestart,
                     wavestop,
                     datetime.now() - starttime))
        sys.stdout.flush()

        return [pixelarr, model, optimizer, datetime.now() - starttime]
# ---- Example #20 (scraped-snippet separator; original artifact text: "示例#20" / "0") ----
def main():
    """Train or evaluate a binary foundation-type ResNet-50 classifier.

    In training mode (not args.eval) it fine-tunes on weighted-sampled
    train data and validates every epoch; in eval mode it runs a single
    test pass and prints precision/recall/F1. Depends on module-level
    `args`, `log_dir`, `lr`, `summary_writer`, `device`, and the helpers
    `parse`/`evaluate`/`RAdam`/`resnet50`/`Foundation_Type_Binary`.
    """
    best_checkpoint = os.path.join(log_dir,
                                   'checkpoints/best_self_trained.pkl')

    # ImageNet normalization statistics
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transforms = [
        transforms.RandomResizedCrop(224, scale=(0.7, 1.)),
        transforms.RandomGrayscale(p=0.5),
        transforms.ColorJitter(0.5, 0.5, 0.5, 0.5),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(), normalize
    ]

    train_transforms = transforms.Compose(train_transforms)

    val_transforms = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(), normalize
    ]

    val_transforms = transforms.Compose(val_transforms)

    if not args.eval:

        train_dataset = Foundation_Type_Binary(
            args.train_data,
            transform=train_transforms,
            mask_buildings=args.mask_buildings,
            load_masks=True)

        val_dataset = Foundation_Type_Binary(
            args.val_data,
            transform=val_transforms,
            mask_buildings=args.mask_buildings,
            load_masks=True)

        # class-balance the training stream via weighted sampling
        train_weights = np.array(train_dataset.train_weights)
        train_sampler = torch.utils.data.WeightedRandomSampler(
            train_weights, len(train_weights), replacement=True)

        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  drop_last=False,
                                  sampler=train_sampler)
        val_loader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.workers,
                                drop_last=False)

    else:
        test_dataset = Foundation_Type_Binary(
            args.test_data,
            transform=val_transforms,
            mask_buildings=args.mask_buildings,
            load_masks=True)
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.workers,
                                 drop_last=False)

    # single-logit output for binary classification (BCEWithLogitsLoss below)
    model = resnet50(low_dim=1)

    # Freeze all layers apart from the final layer
    if args.freeze_layers:
        ct = 0
        for child in model.children():
            ct += 1
            if ct < 10:
                print('Freezing {}'.format(child))
                for param in child.parameters():
                    param.requires_grad = False

    summary_writer.add_text('Architecture', model.__class__.__name__)
    summary_writer.add_text('Train Transforms', str(train_transforms))
    summary_writer.add_text('Val Transforms', str(val_transforms))

    criterion = nn.BCEWithLogitsLoss()

    optimizer = RAdam(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.9)

    summary_writer.add_text('Criterion', str(criterion))
    summary_writer.add_text('Optimizer', str(optimizer))

    is_train = not args.eval

    # NOTE(review): best_perf starts at 1e6 yet the update below requires
    # current_perf > best_perf, so the "best" checkpoint is only written
    # once the metric exceeds 1e6. For an error-like metric (the final
    # print says "mse") the comparison direction and/or the init look
    # inverted — confirm what evaluate() returns in train mode.
    best_perf = 1e6

    model = nn.DataParallel(model).to(device)
    if args.pretrained:
        try:
            state_dict = torch.load(args.checkpoint)['state_dict']
        except KeyError:  # Format of NPID checkpoints, and checkpoints created with train.py differ
            state_dict = torch.load(args.checkpoint)

        # strict=False tolerates head mismatches between NPID and this model
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if len(missing) or len(unexpected):
            print('Missing or unexpected keys: {},{}'.format(
                missing, unexpected))

    if is_train is True:
        for epoch in range(args.start_epoch, args.epochs):
            model.train()

            print('Training epoch: {}'.format(epoch))
            y_train_pred, y_train_gt, avg_train_loss = parse(
                model, train_loader, criterion, optimizer, 'train', epoch)
            # decay lr by 0.9 once per epoch
            scheduler.step()
            evaluate(summary_writer, 'Train', y_train_gt, y_train_pred,
                     avg_train_loss, train_loader.dataset.classes, epoch)

            print('Validation epoch: {}'.format(epoch))
            with torch.no_grad():
                y_val_pred, y_val_gt, avg_val_loss = parse(
                    model, val_loader, criterion, None, 'val', epoch)

            current_perf = evaluate(summary_writer, 'Val', y_val_gt,
                                    y_val_pred, avg_val_loss,
                                    train_loader.dataset.classes, epoch)

            if current_perf > best_perf:
                best_perf = current_perf
                print('current best performance measure,', best_perf)
                torch.save(model.state_dict(), best_checkpoint)

            # Save regular checkpoint every epoch
            latest_model_path = os.path.join(
                log_dir, 'checkpoint_epoch_{}.pkl'.format(epoch))
            torch.save(model.state_dict(), latest_model_path)
        print('best performance measure mse' + str(best_perf))

    else:
        print('Only test mode:')
        with torch.no_grad():
            y_val_pred, y_val_gt, avg_val_loss = parse(model, test_loader,
                                                       criterion, None, 'test',
                                                       0)

        # here evaluate() is indexed like (precision, recall, F1, ...)
        current_perf = evaluate(summary_writer, 'Test', y_val_gt, y_val_pred,
                                avg_val_loss, test_loader.dataset.classes, 0)

        print('F1: {}'.format(current_perf[2]))
        print('Precision: {}'.format(current_perf[0]))
        print('Recall: {}'.format(current_perf[1]))
        exit()
 def configure_optimizers(
         self) -> Tuple[List[Optimizer], List[_LRScheduler]]:
     """Build the Adam optimizer and its StepLR schedule (lr x0.1 every 15 epochs)."""
     opt = Adam(self.model.parameters(), self.lr)
     sched = StepLR(opt, step_size=15, gamma=0.1)
     # Lightning-style contract: parallel lists of optimizers and schedulers.
     return [opt], [sched]
def main():
    """Evaluate a few-shot video relation network on UCF101.

    For each of ``EPISODE`` rounds, samples ``TEST_EPISODE`` random few-shot
    tasks, scores query clips against class-summed support features with the
    relation network, prints per-round accuracy with a confidence interval,
    and finally the average accuracy over all rounds.

    Relies on module-level globals: FEATURE_DIM, RELATION_DIM, CLASS_NUM,
    SAMPLE_NUM_PER_CLASS, LEARNING_RATE, GPU, EPISODE, TEST_EPISODE, and the
    task-generator module ``tg``.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_folders,metatest_folders = tg.ucf101_folders()

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM,RELATION_DIM)
    feature_encoder = nn.DataParallel(feature_encoder)
    relation_network = nn.DataParallel(relation_network)

    feature_encoder.cuda(GPU)
    relation_network.cuda(GPU)

    # Optimizers/schedulers are built for parity with the training script;
    # they are never stepped in this evaluation-only loop.
    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,step_size=100000,gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,step_size=100000,gamma=0.5)

    # Restore pretrained weights when checkpoints are present.
    if os.path.exists(str("./model/ucf_feature_encoder_c3d_8frame" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")):
        feature_encoder.load_state_dict(torch.load(str("./model/ucf_feature_encoder_c3d_8frame" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")))
        print("load feature encoder success")
    if os.path.exists(str("./model/ucf_relation_network_c3d_8frame"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")):
        relation_network.load_state_dict(torch.load(str("./model/ucf_relation_network_c3d_8frame"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl")))
        print("load relation network success")

    total_accuracy = 0.0
    for episode in range(EPISODE):

        # test
        print("Testing...")

        accuracies = []
        for i in range(TEST_EPISODE):
            total_rewards = 0
            # 15 query clips per class per task.
            task = tg.Ucf101Task(metatest_folders,CLASS_NUM,SAMPLE_NUM_PER_CLASS,15)
            sample_dataloader = tg.get_ucf101_data_loader(task,num_per_class=SAMPLE_NUM_PER_CLASS,split="train",shuffle=False)
            num_per_class = 5
            test_dataloader = tg.get_ucf101_data_loader(task,num_per_class=num_per_class,split="test",shuffle=False)

            # BUGFIX: `.__iter__().next()` is Python-2 only; Python 3
            # iterators expose __next__ via the next() builtin.
            sample_images,sample_labels = next(iter(sample_dataloader))
            for test_images,test_labels in test_dataloader:
                batch_size = test_labels.shape[0]
                # calculate features; support features are summed over the
                # shot dimension to get one prototype map per class
                sample_features = feature_encoder(Variable(sample_images).cuda(GPU)) # 5x64
                sample_features = sample_features.view(CLASS_NUM,SAMPLE_NUM_PER_CLASS,FEATURE_DIM,19,19)
                sample_features = torch.sum(sample_features,1).squeeze(1)
                test_features = feature_encoder(Variable(test_images).cuda(GPU)) # 20x64

                # calculate relations: pair every query with every class
                # prototype, concat on the channel axis, and score
                sample_features_ext = sample_features.unsqueeze(0).repeat(batch_size,1,1,1,1,1)
                sample_features_ext = torch.squeeze(sample_features_ext)
                test_features_ext = test_features.unsqueeze(0).repeat(1*CLASS_NUM,1,1,1,1,1)
                test_features_ext = torch.transpose(test_features_ext,0,1)
                test_features_ext = torch.squeeze(test_features_ext)
                relation_pairs = torch.cat((sample_features_ext,test_features_ext),2).view(-1,FEATURE_DIM*2,19,19)
                relations = relation_network(relation_pairs).view(-1,CLASS_NUM)

                _,predict_labels = torch.max(relations.data,1)

                rewards = [1 if predict_labels[j].cuda()==test_labels[j].cuda() else 0 for j in range(batch_size)]

                total_rewards += np.sum(rewards)

            # 15 queries per class, CLASS_NUM classes per task.
            accuracy = total_rewards/1.0/CLASS_NUM/15
            accuracies.append(accuracy)

        test_accuracy,h = mean_confidence_interval(accuracies)

        print("test accuracy:",test_accuracy,"h:",h)

        total_accuracy += test_accuracy

    print("aver_accuracy:",total_accuracy/EPISODE)
示例#23
0
    if cuda:
        model = model.cuda()

    if mode == 'train':
        # define loss function
        loss_fn = nn.MSELoss()
        if cuda:
            loss_fn = loss_fn.cuda()

        # set optimizer
        optimizer = Adam(
            [param for param in model.parameters() if param.requires_grad],
            lr=base_lr,
            weight_decay=1e-4)
        # learning decay
        scheduler = StepLR(optimizer, step_size=40, gamma=0.1)

        # get data loader
        train_dataloader, _ = data_loader(root=DATASET_PATH,
                                          phase='train',
                                          batch_size=batch,
                                          max_vector=100)
        validate_dataloader, validate_label_file = data_loader(
            root=DATASET_PATH,
            phase='validate',
            batch_size=batch,
            max_vector=100)
        time_ = datetime.datetime.now()
        num_batches = len(train_dataloader)
        #print("num batches : ", num_batches)
示例#24
0
    # relevant to Transfer learning with fixed features

    if (model_params['per_layer_rates']):
        optimizer = train_control['optimizer'](
            [{
                'params': model.get_params_layer(i),
                'lr': model.get_lr_layer(i)
            } for i in range(1, 7)], **optimizer_params)
    else:
        optimizer = train_control['optimizer'](filter(
            lambda p: p.requires_grad, model.parameters()), **optimizer_params)

    # Initiate Scheduler

    if (train_control['lr_scheduler_type'] == 'step'):
        scheduler = StepLR(optimizer, **train_control['step_scheduler_args'])
    elif (train_control['lr_scheduler_type'] == 'exp'):
        scheduler = ExponentialLR(optimizer,
                                  **train_control['exp_scheduler_args'])
    elif (train_control['lr_scheduler_type'] == 'plateau'):
        scheduler = ReduceLROnPlateau(
            optimizer, **train_control['plateau_scheduler_args'])
    else:
        scheduler = StepLR(optimizer, step_size=100, gamma=1)

    if model_params['pytorch_device'] == 'gpu':
        with torch.cuda.device(model_params['cuda_device']):
            model_trainer = ModelTrainer(model,
                                         train_dataset_loader,
                                         valid_dataset_loader,
                                         test_dataset_loader,
示例#25
0
def main(epochs, lr, optimizer_type, es_param):
    """Train the SO3 model on MNIST for fun, report performance metrics.

    Args:
        epochs: number of training epochs.
        lr: learning rate for the chosen optimizer.
        optimizer_type: key into the module-level ``optimizers`` dict.
        es_param: early-stopping patience, measured in recorded
            per-batch validation-loss entries.

    Training for 5 epochs with SGD and a learning rate of .01 yielded:
    Accuracy:
    Precisions:
    Recalls:

    Not bad for a tiny model first pass! Would be cool to look at the clustering
    """

    # Load MNIST data
    transform_group = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # Train
    mnist_train = datasets.MNIST('data/mnist',
                                 train=True,
                                 download=True,
                                 transform=transform_group)
    # split into train val
    train, val = torch.utils.data.random_split(mnist_train, [50000, 10000])
    train_loader = torch.utils.data.DataLoader(train, batch_size=1)
    val_loader = torch.utils.data.DataLoader(val, batch_size=1)

    # Test
    mnist_test = datasets.MNIST('data/mnist',
                                train=False,
                                download=True,
                                transform=transform_group)
    test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=1)

    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Initialize network
    model = SO3Classifier().to(device)

    # Training utilities
    optimizer = optimizers[optimizer_type](model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=1)

    # Train loop
    train_losses = []
    val_losses = []
    for epoch in range(1, epochs + 1):
        # train
        epoch_train_losses = train_epoch(model, device, train_loader,
                                         optimizer, epoch)
        scheduler.step()
        train_losses.extend(epoch_train_losses)
        # val
        epoch_val_losses = val_epoch(model, device, val_loader, epoch)
        val_losses.extend(epoch_val_losses)

        # Basic early stopping: compare each of the last es_param recorded
        # val losses against the entry es_param positions earlier.
        # NOTE(review): assumes len(val_losses) > es_param after the first
        # epoch; with fewer entries the negative indices would wrap around
        # to the start of the list — confirm against val_epoch's output size.
        increasing_loss = [(val_losses[-(i + 1)] > val_losses[-(es_param + 1)])
                           for i in range(es_param)]

        # Stop once every recent loss exceeds the baseline (e.g. if val loss
        # starts monotonically increasing from overfitting).
        if all(increasing_loss):
            print("Early stopping activated.")
            break

    # Save train losses (plain string literals: no placeholders needed)
    with open('data/training_logs/train_losses.pkl', 'wb') as file:
        pickle.dump(train_losses, file)
    # Save val losses
    with open('data/training_logs/val_losses.pkl', 'wb') as file:
        pickle.dump(val_losses, file)

    # One test loop on test data
    test_epoch(model, test_loader, device)
示例#26
0
class Trainer():
    """Crowd-counting training driver.

    Builds the network, optimizer, and LR schedule from the global ``cfg``,
    runs the train/validate loop, logs to TensorBoard, and checkpoints the
    best model via ``update_model``.
    """

    def __init__(self, dataloader, cfg_data, pwd):
        """Wire up the model, optimizer, scheduler, loaders, and logging.

        Args:
            dataloader: factory returning (train_loader, val_loader,
                restore_transform).
            cfg_data: dataset-specific config (e.g. LOG_PARA scaling).
            pwd: working directory, forwarded to the experiment logger.
        """
        self.cfg_data = cfg_data

        self.data_mode = cfg.DATASET
        self.exp_name = cfg.EXP_NAME
        self.exp_path = cfg.EXP_PATH
        self.pwd = pwd

        self.net_name = cfg.NET

        if self.net_name in ['SANet']:
            loss_1_fn = torch.nn.MSELoss()
            from misc import pytorch_ssim
            loss_2_fn = pytorch_ssim.SSIM(window_size=11)
        else:
            # BUGFIX: previously any other net name fell through and hit a
            # NameError on loss_1_fn below; fail fast with a clear message.
            raise ValueError(
                'Unsupported net {}: no loss functions defined'.format(self.net_name))

        self.net = CrowdCounter(cfg.GPU_ID,self.net_name,loss_1_fn,loss_2_fn).cuda()
        self.optimizer = optim.Adam(self.net.CCN.parameters(), lr=cfg.LR, weight_decay=1e-4)
        # self.optimizer = optim.SGD(self.net.parameters(), cfg.LR, momentum=0.95,weight_decay=5e-4)
        self.scheduler = StepLR(self.optimizer, step_size=cfg.NUM_EPOCH_LR_DECAY, gamma=cfg.LR_DECAY)

        self.train_record = {'best_mae': 1e20, 'best_mse':1e20, 'best_model_name': ''}
        self.timer = {'iter time' : Timer(),'train time' : Timer(),'val time' : Timer()}

        self.epoch = 0
        self.i_tb = 0

        # Optional warm start from a GCC-pretrained model.
        if cfg.PRE_GCC:
            self.net.load_state_dict(torch.load(cfg.PRE_GCC_MODEL))

        self.train_loader, self.val_loader, self.restore_transform = dataloader()

        # Full resume: network, optimizer, scheduler, and bookkeeping.
        if cfg.RESUME:
            latest_state = torch.load(cfg.RESUME_PATH)
            self.net.load_state_dict(latest_state['net'])
            self.optimizer.load_state_dict(latest_state['optimizer'])
            self.scheduler.load_state_dict(latest_state['scheduler'])
            self.epoch = latest_state['epoch'] + 1
            self.i_tb = latest_state['i_tb']
            self.train_record = latest_state['train_record']
            self.exp_path = latest_state['exp_path']
            self.exp_name = latest_state['exp_name']

        self.writer, self.log_txt = logger(self.exp_path, self.exp_name, self.pwd, 'exp', resume=cfg.RESUME)

    def forward(self):
        """Main loop: train every epoch, decay LR after cfg.LR_DECAY_START,
        and validate on the configured cadence."""
        # self.validate_V3()
        for epoch in range(self.epoch, cfg.MAX_EPOCH):
            self.epoch = epoch
            if epoch > cfg.LR_DECAY_START:
                self.scheduler.step()

            # training
            self.timer['train time'].tic()
            self.train()
            self.timer['train time'].toc(average=False)

            print( 'train time: {:.2f}s'.format(self.timer['train time'].diff) )
            print( '='*20 )

            # validation
            if epoch%cfg.VAL_FREQ==0 or epoch>cfg.VAL_DENSE_START:
                self.timer['val time'].tic()
                if self.data_mode in ['SHHA', 'SHHB', 'QNRF', 'UCF50','Venice','Venezia_cc']:
                    self.validate_V1()
                elif self.data_mode == 'WE':
                    self.validate_V2()
                elif self.data_mode == 'GCC':
                    self.validate_V3()
                self.timer['val time'].toc(average=False)
                print( 'val time: {:.2f}s'.format(self.timer['val time'].diff) )

    def train(self): # training for all datasets
        """One training epoch: forward, combined two-term loss, backward."""
        self.net.train()
        for i, data in enumerate(self.train_loader, 0):
            self.timer['iter time'].tic()
            img, gt_map = data
            img = Variable(img).cuda()
            gt_map = Variable(gt_map).cuda()

            self.optimizer.zero_grad()
            pred_map = self.net(img, gt_map)
            loss1,loss2 = self.net.loss
            loss = loss1+loss2
            loss.backward()
            self.optimizer.step()

            if (i + 1) % cfg.PRINT_FREQ == 0:
                self.i_tb += 1
                self.writer.add_scalar('train_loss', loss.item(), self.i_tb)
                self.writer.add_scalar('train_loss1', loss1.item(), self.i_tb)
                self.writer.add_scalar('train_loss2', loss2.item(), self.i_tb)
                self.timer['iter time'].toc(average=False)
                # LR is scaled by 1e4 purely for readable printing.
                print( '[ep %d][it %d][loss %.4f][lr %.4f][%.2fs]' % \
                        (self.epoch + 1, i + 1, loss.item(), self.optimizer.param_groups[0]['lr']*10000, self.timer['iter time'].diff) )
                print( '        [cnt: gt: %.1f pred: %.2f]' % (gt_map[0].sum().data/self.cfg_data.LOG_PARA, pred_map[0].sum().data/self.cfg_data.LOG_PARA) )

    def validate_V1(self):# validate_V1 for SHHA, SHHB, UCF-QNRF, UCF50
        """Validate on single-loader datasets; track loss/MAE/MSE and
        checkpoint via update_model."""
        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        for vi, data in enumerate(self.val_loader, 0):
            img, gt_map = data

            with torch.no_grad():
                img = Variable(img).cuda()
                gt_map = Variable(gt_map).cuda()

                pred_map = self.net.forward(img,gt_map)

                pred_map = pred_map.data.cpu().numpy()
                gt_map = gt_map.data.cpu().numpy()

                # Counts are recovered by undoing the density-map scaling.
                for i_img in range(pred_map.shape[0]):
                    pred_cnt = np.sum(pred_map[i_img])/self.cfg_data.LOG_PARA
                    gt_count = np.sum(gt_map[i_img])/self.cfg_data.LOG_PARA

                    loss1,loss2 = self.net.loss
                    loss = loss1.item()+loss2.item()
                    losses.update(loss)
                    maes.update(abs(gt_count-pred_cnt))
                    mses.update((gt_count-pred_cnt)*(gt_count-pred_cnt))
                if vi==0:
                    vis_results(self.exp_name, self.epoch, self.writer, self.restore_transform, img, pred_map, gt_map)

        mae = maes.avg
        mse = np.sqrt(mses.avg)
        loss = losses.avg

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)
        self.writer.add_scalar('mse', mse, self.epoch + 1)

        self.train_record = update_model(self.net,self.optimizer,self.scheduler,self.epoch,self.i_tb,self.exp_path,self.exp_name, \
            [mae, mse, loss],self.train_record,self.log_txt)
        print_summary(self.exp_name,[mae, mse, loss],self.train_record)

    def validate_V2(self):# validate_V2 for WE
        """Validate on WE: five sub-scene loaders, per-scene averaged MAE."""
        self.net.eval()

        losses = AverageCategoryMeter(5)
        maes = AverageCategoryMeter(5)

        for i_sub,i_loader in enumerate(self.val_loader,0):

            for vi, data in enumerate(i_loader, 0):
                img, gt_map = data

                with torch.no_grad():
                    img = Variable(img).cuda()
                    gt_map = Variable(gt_map).cuda()

                    pred_map = self.net.forward(img,gt_map)

                    pred_map = pred_map.data.cpu().numpy()
                    gt_map = gt_map.data.cpu().numpy()

                    for i_img in range(pred_map.shape[0]):
                        pred_cnt = np.sum(pred_map[i_img])/self.cfg_data.LOG_PARA
                        gt_count = np.sum(gt_map[i_img])/self.cfg_data.LOG_PARA

                        losses.update(self.net.loss.item(),i_sub)
                        maes.update(abs(gt_count-pred_cnt),i_sub)
                    if vi==0:
                        vis_results(self.exp_name, self.epoch, self.writer, self.restore_transform, img, pred_map, gt_map)

        mae = np.average(maes.avg)
        loss = np.average(losses.avg)

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)

        # WE has no MSE metric; a 0 placeholder keeps the record shape.
        self.train_record = update_model(self.net,self.optimizer,self.scheduler,self.epoch,self.i_tb,self.exp_path,self.exp_name, \
            [mae, 0, loss],self.train_record,self.log_txt)
        print_summary(self.exp_name,[mae, 0, loss],self.train_record)

    def validate_V3(self):# validate_V3 for GCC
        """Validate on GCC with per-attribute (level/time/weather) breakdowns."""
        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        c_maes = {'level':AverageCategoryMeter(9), 'time':AverageCategoryMeter(8),'weather':AverageCategoryMeter(7)}
        c_mses = {'level':AverageCategoryMeter(9), 'time':AverageCategoryMeter(8),'weather':AverageCategoryMeter(7)}

        for vi, data in enumerate(self.val_loader, 0):
            img, gt_map, attributes_pt = data

            with torch.no_grad():
                img = Variable(img).cuda()
                gt_map = Variable(gt_map).cuda()

                pred_map = self.net.forward(img,gt_map)

                pred_map = pred_map.data.cpu().numpy()
                gt_map = gt_map.data.cpu().numpy()

                for i_img in range(pred_map.shape[0]):
                    # BUGFIX: sum the i_img-th map, not the whole batch
                    # (matches validate_V1; the old code summed pred_map/gt_map
                    # across all images in the batch for every i_img).
                    pred_cnt = np.sum(pred_map[i_img])/self.cfg_data.LOG_PARA
                    gt_count = np.sum(gt_map[i_img])/self.cfg_data.LOG_PARA

                    s_mae = abs(gt_count-pred_cnt)
                    s_mse = (gt_count-pred_cnt)*(gt_count-pred_cnt)

                    loss1,loss2 = self.net.loss
                    loss = loss1.item()+loss2.item()
                    losses.update(loss)
                    maes.update(s_mae)
                    mses.update(s_mse)
                    attributes_pt = attributes_pt.squeeze()
                    c_maes['level'].update(s_mae,attributes_pt[0])
                    c_mses['level'].update(s_mse,attributes_pt[0])
                    # BUGFIX: integer division — /3 yields a float category
                    # index under Python 3; the time attribute is binned by 3.
                    c_maes['time'].update(s_mae,attributes_pt[1]//3)
                    c_mses['time'].update(s_mse,attributes_pt[1]//3)
                    c_maes['weather'].update(s_mae,attributes_pt[2])
                    c_mses['weather'].update(s_mse,attributes_pt[2])

                if vi==0:
                    vis_results(self.exp_name, self.epoch, self.writer, self.restore_transform, img, pred_map, gt_map)

        loss = losses.avg
        mae = maes.avg
        mse = np.sqrt(mses.avg)

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)
        self.writer.add_scalar('mse', mse, self.epoch + 1)

        self.train_record = update_model(self.net,self.optimizer,self.scheduler,self.epoch,self.i_tb,self.exp_path,self.exp_name, \
            [mae, mse, loss],self.train_record,self.log_txt)

        print_GCC_summary(self.log_txt,self.epoch,[mae, mse, loss],self.train_record,c_maes,c_mses)
示例#27
0
def main(cuda, batch_size, pretrain_epochs, finetune_epochs, testing_mode):
    """Run the full DEC pipeline on MNIST: SDAE pretrain, SDAE finetune,
    then DEC cluster training; report final cluster accuracy.

    Args:
        cuda: move models/data to GPU when True.
        batch_size: batch size for the autoencoder stages.
        pretrain_epochs: epochs for layerwise SDAE pretraining.
        finetune_epochs: epochs for end-to-end SDAE finetuning.
        testing_mode: when True, use the tiny test datasets and skip the
            confusion-matrix artifact.
    """
    writer = SummaryWriter()  # create the TensorBoard object

    # callback function to call during training, uses writer from the scope

    def training_callback(epoch, lr, loss, validation_loss):
        writer.add_scalars(
            "data/autoencoder",
            {
                "lr": lr,
                "loss": loss,
                "validation_loss": validation_loss,
            },
            epoch,
        )

    ds_train = CachedMNIST(train=True, cuda=cuda,
                           testing_mode=testing_mode)  # training dataset
    ds_val = CachedMNIST(train=False, cuda=cuda,
                         testing_mode=testing_mode)  # evaluation dataset
    autoencoder = StackedDenoisingAutoEncoder([28 * 28, 500, 500, 2000, 10],
                                              final_activation=None)
    if cuda:
        autoencoder.cuda()
    print("Pretraining stage.")
    ae.pretrain(
        ds_train,
        autoencoder,
        cuda=cuda,
        validation=ds_val,
        epochs=pretrain_epochs,
        batch_size=batch_size,
        optimizer=lambda model: SGD(model.parameters(), lr=0.1, momentum=0.9),
        scheduler=lambda x: StepLR(x, 100, gamma=0.1),
        corruption=0.2,
    )
    print("Training stage.")
    ae_optimizer = SGD(params=autoencoder.parameters(), lr=0.1, momentum=0.9)
    ae.train(
        ds_train,
        autoencoder,
        cuda=cuda,
        validation=ds_val,
        epochs=finetune_epochs,
        batch_size=batch_size,
        optimizer=ae_optimizer,
        scheduler=StepLR(ae_optimizer, 100, gamma=0.1),
        corruption=0.2,
        update_callback=training_callback,
    )
    print("DEC stage.")
    model = DEC(cluster_number=10,
                hidden_dimension=10,
                encoder=autoencoder.encoder)
    if cuda:
        model.cuda()
    dec_optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
    train(
        dataset=ds_train,
        model=model,
        epochs=100,
        batch_size=256,
        optimizer=dec_optimizer,
        stopping_delta=0.000001,
        cuda=cuda,
    )
    predicted, actual = predict(ds_train,
                                model,
                                1024,
                                silent=True,
                                return_actual=True,
                                cuda=cuda)
    actual = actual.cpu().numpy()
    predicted = predicted.cpu().numpy()
    # Clusters are unordered; cluster_accuracy finds the best label mapping.
    reassignment, accuracy = cluster_accuracy(actual, predicted)
    print("Final DEC accuracy: %s" % accuracy)
    if not testing_mode:
        predicted_reassigned = [reassignment[item]
                                for item in predicted]  # TODO numpify
        confusion = confusion_matrix(actual, predicted_reassigned)
        normalised_confusion = (confusion.astype("float") /
                                confusion.sum(axis=1)[:, np.newaxis])
        confusion_id = uuid.uuid4().hex
        sns.heatmap(normalised_confusion).get_figure().savefig(
            "confusion_%s.png" % confusion_id)
        print("Writing out confusion diagram with UUID: %s" % confusion_id)
    # BUGFIX: always close the writer; it was previously skipped (and the
    # event file left unflushed) whenever testing_mode was True.
    writer.close()
示例#28
0
# Model: pick a backbone and give it a single-logit output head.
if 'efficientnet' in args.modelname:
    net = EfficientNet.from_pretrained(args.modelname, num_classes=1)
elif 'resnet34' in args.modelname:
    net = torchvision.models.resnet34(pretrained=True)
    net.fc = nn.Linear(in_features=net.fc.in_features, out_features=1)
elif 'resnet50' in args.modelname:
    net = torchvision.models.resnet50(pretrained=True)
    net.fc = nn.Linear(in_features=net.fc.in_features, out_features=1)
else:
    # Fail fast: previously an unknown name fell through and caused a
    # NameError on `net` further down.
    raise ValueError('unknown model name: {}'.format(args.modelname))

# Optimizer
if args.optimizer == 'adam':
    optimizer = optim.Adam(net.parameters(), lr=args.learningrate)
elif args.optimizer == 'radam':
    optimizer = RAdam(net.parameters(), lr=args.learningrate)
elif args.optimizer == 'sgd':
    optimizer = optim.SGD(net.parameters(), lr=args.learningrate,
                          momentum=0.9, weight_decay=0.)
else:
    # Fail fast: previously optimizer stayed None and crashed inside
    # train_model instead of reporting the bad choice.
    raise ValueError('unknown optimizer: {}'.format(args.optimizer))

# Scheduler (optional: train_model accepts None to disable LR scheduling)
scheduler = None
if 'step' in args.scheduler:
    scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
elif 'exp' in args.scheduler:
    scheduler = ExponentialLR(optimizer, gamma=0.95)
elif 'cycle' in args.scheduler:
    scheduler = CosineAnnealingWarmRestarts(
        optimizer, T_0=5, T_mult=1, eta_min=args.learningrate * 0.1)

# Train  #########################################################################
train_model(dataloaders, net, device, optimizer, scheduler, batch_num,
            num_epochs=epoch, exp=exp)
示例#29
0
def main():
    """Entry point for VoiceFilter training.

    Parses CLI args, sets up logging/TensorBoard/dataloaders, loads the
    (frozen) speaker embedder, optionally resumes model+optimizer from a
    checkpoint, then trains and validates every epoch with a 0.98/epoch
    LR decay.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch Voice Filter')
    parser.add_argument('-b', '--base_dir', type=str, default='.',
                        help="Root directory of run.")
    parser.add_argument('--checkpoint_path', type=str, default=None,
                        help='Path to last checkpoint')
    parser.add_argument('-e', '--embedder_path', type=str, required=True,
                        help="path of embedder model pt file")
    parser.add_argument('-m', '--model', type=str, required=True,
                        help="Name of the model. Used for both logging and saving checkpoints.")
    args = parser.parse_args()

    # (was a redundant `x if x is not None else None` ternary)
    chkpt_path = args.checkpoint_path

    pt_dir = os.path.join(args.base_dir, config.log['chkpt_dir'], args.model)
    os.makedirs(pt_dir, exist_ok=True)

    log_dir = os.path.join(args.base_dir, config.log['log_dir'], args.model)
    os.makedirs(log_dir, exist_ok=True)

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(os.path.join(log_dir,
                '%s-%d.log' % (args.model, time.time()))),
            logging.StreamHandler()
        ]
    )
    logger = logging.getLogger()
    writer = MyWriter(log_dir)

    trainloader = create_dataloader(train=True)
    testloader = create_dataloader(train=False)

    # Speaker embedder is loaded once and frozen (eval mode) for the run.
    embedder_pt = torch.load(args.embedder_path)
    embedder = SpeechEmbedder().cuda()
    embedder.load_state_dict(embedder_pt)
    embedder.eval()

    # NOTE(review): the model is not explicitly moved to CUDA here —
    # presumably handled inside train(); confirm.
    model = nn.DataParallel(VoiceFilter())
    optimizer = torch.optim.Adam(model.parameters(),lr=config.train['adam'])
    audio = Audio()

    starting_epoch = 1

    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint_file = torch.load(chkpt_path)
        model.load_state_dict(checkpoint_file['model'])
        optimizer.load_state_dict(checkpoint_file['optimizer'])
        # NOTE(review): this re-runs the saved epoch rather than the next
        # one — confirm whether 'epoch' is stored as last-finished.
        starting_epoch = checkpoint_file['epoch']
    else:
        logger.info("Starting new training run")

    scheduler = StepLR(optimizer, step_size=1, gamma=0.98)
    for epoch in range(starting_epoch, config.train['epoch'] + 1):
        train(embedder,model,optimizer,trainloader,writer,logger,epoch,pt_dir)
        logger.info("Starting to validate epoch...")
        validate(audio,model,embedder,testloader,writer,epoch)
        scheduler.step()

    model_saver(model,optimizer,pt_dir,config.train['epoch'])
示例#30
0
    def __init__(self, args):
        """Build a VAE-style reconstructor plus disentanglement discriminators.

        Sets up encoder/decoder networks for the chosen dataset, six
        class/membership discriminators over splits of the latent code,
        per-network Adam optimizers, LR schedulers, device placement, and
        early-stopping bookkeeping; optionally resumes from a checkpoint.

        Args:
            args: parsed command-line namespace carrying dataset choice,
                dimensions, learning rates, loss weights, and run options.
        """
        # Output directory for reconstructions/checkpoints; created if absent.
        self.reconstruction_path = args.reconstruction_path
        if not os.path.exists(self.reconstruction_path):
            os.makedirs(self.reconstruction_path)

        self.beta = args.beta
        self.train_batch_size = args.train_batch_size
        self.test_batch_size = args.test_batch_size
        self.epochs = args.epochs
        self.early_stop = args.early_stop
        self.early_stop_observation_period = args.early_stop_observation_period
        # Hard-coded off: the StepLR schedulers created below are inert
        # unless this flag is flipped elsewhere — TODO confirm intent.
        self.use_scheduler = False
        self.print_training = args.print_training
        self.class_num = args.class_num
        self.disentangle_with_reparameterization = args.disentangle_with_reparameterization

        # The latent code z is split in half: the first half carries class
        # information, the second half membership information.
        self.z_dim = args.z_dim
        self.disc_input_dim = int(self.z_dim / 2)
        self.class_idx = range(0, self.disc_input_dim)
        self.membership_idx = range(self.disc_input_dim, self.z_dim)

        self.nets = dict()

        # Conv encoder/decoder for image datasets; FC for tabular ones.
        if args.dataset in ['MNIST', 'Fashion-MNIST', 'CIFAR-10', 'SVHN']:
            if args.dataset in ['MNIST', 'Fashion-MNIST']:
                self.num_channels = 1
            elif args.dataset in ['CIFAR-10', 'SVHN']:
                self.num_channels = 3

            self.nets['encoder'] = module.VAEConvEncoder(
                self.z_dim, self.num_channels)
            self.nets['decoder'] = module.VAEConvDecoder(
                self.z_dim, self.num_channels)

        elif args.dataset in ['adult', 'location']:
            self.nets['encoder'] = module.VAEFCEncoder(args.encoder_input_dim,
                                                       self.z_dim)
            self.nets['decoder'] = module.FCDecoder(args.encoder_input_dim,
                                                    self.z_dim)

        # Discriminators over the full code (fz), class half (cz), and
        # membership half (mz); membership ones also take the class one-hot.
        self.discs = {
            'class_fz':
            module.ClassDiscriminator(self.z_dim, args.class_num),
            'class_cz':
            module.ClassDiscriminator(self.disc_input_dim, args.class_num),
            'class_mz':
            module.ClassDiscriminator(self.disc_input_dim, args.class_num),
            'membership_fz':
            module.MembershipDiscriminator(self.z_dim + args.class_num, 1),
            'membership_cz':
            module.MembershipDiscriminator(
                self.disc_input_dim + args.class_num, 1),
            'membership_mz':
            module.MembershipDiscriminator(
                self.disc_input_dim + args.class_num, 1),
        }

        self.recon_loss = self.get_loss_function()
        self.class_loss = nn.CrossEntropyLoss(reduction='sum')
        self.membership_loss = nn.BCEWithLogitsLoss(reduction='sum')

        # optimizer: one Adam per network and per discriminator, keyed by name.
        self.optimizer = dict()
        for net_type in self.nets:
            self.optimizer[net_type] = optim.Adam(
                self.nets[net_type].parameters(),
                lr=args.recon_lr,
                betas=(0.5, 0.999))
        self.discriminator_lr = args.disc_lr
        for disc_type in self.discs:
            self.optimizer[disc_type] = optim.Adam(
                self.discs[disc_type].parameters(),
                lr=self.discriminator_lr,
                betas=(0.5, 0.999))

        # Relative weights of the reconstruction and disentanglement terms.
        self.weights = {
            'recon': args.recon_weight,
            'class_cz': args.class_cz_weight,
            'class_mz': args.class_mz_weight,
            'membership_cz': args.membership_cz_weight,
            'membership_mz': args.membership_mz_weight,
        }

        self.scheduler_enc = StepLR(self.optimizer['encoder'],
                                    step_size=50,
                                    gamma=0.1)
        self.scheduler_dec = StepLR(self.optimizer['decoder'],
                                    step_size=50,
                                    gamma=0.1)

        # to device
        self.device = torch.device("cuda:{}".format(args.gpu_id))
        for net_type in self.nets:
            self.nets[net_type] = self.nets[net_type].to(self.device)
        for disc_type in self.discs:
            self.discs[disc_type] = self.discs[disc_type].to(self.device)

        # Disentanglement is active only when at least one of its loss
        # weights is non-zero.
        self.disentangle = (
            self.weights['class_cz'] + self.weights['class_mz'] +
            self.weights['membership_cz'] + self.weights['membership_mz'] > 0)

        self.start_epoch = 0
        self.best_valid_loss = float("inf")
        # self.train_loss = 0
        self.early_stop_count = 0

        # Running discriminator accuracies, tracked per discriminator.
        self.acc_dict = {
            'class_fz': 0,
            'class_cz': 0,
            'class_mz': 0,
            'membership_fz': 0,
            'membership_cz': 0,
            'membership_mz': 0,
        }
        self.best_acc_dict = {}

        if 'cuda' in str(self.device):
            cudnn.benchmark = True

        # Best-effort resume: a missing checkpoint falls back to scratch.
        if args.resume:
            print('==> Resuming from checkpoint..')
            try:
                self.load()
            except FileNotFoundError:
                print(
                    'There is no pre-trained model; Train model from scratch')
示例#31
0
def main():
    """Train a small CNN on MNIST and evaluate on the test set each epoch.

    Command-line flags control batch sizes, epoch count, the Adadelta LR and
    its StepLR decay, CUDA use, and whether to save the final weights.  The
    2nd epoch is wrapped in the CUDA runtime profiler, with NVTX ranges
    marking each epoch and its train/test phases, so one representative
    (post-warm-up) epoch can be captured by nsys/nvprof.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=1.0,
                        metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.7,
                        metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run',
                        action='store_true',
                        default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=100,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        # One loader worker per CPU core to keep the GPU fed.
        cuda_kwargs = {
            'num_workers': multiprocessing.cpu_count(),
            'shuffle': True
        }
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    # Standard MNIST normalization constants (mean, std).
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])
    # Resolve the data directory relative to this script, not the CWD.
    scriptPath = os.path.dirname(os.path.realpath(__file__))
    dataDir = os.path.join(scriptPath, 'data')
    dataset1 = datasets.MNIST(dataDir,
                              train=True,
                              download=True,
                              transform=transform)
    dataset2 = datasets.MNIST(dataDir, train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        # Start profiling from the 2nd epoch (epoch 1 serves as warm-up).
        # Guarded by use_cuda: torch.cuda.cudart() fails on CPU-only runs.
        if epoch == 2 and use_cuda:
            torch.cuda.cudart().cudaProfilerStart()

        nvtx.range_push("Epoch " + str(epoch))
        nvtx.range_push("Train")
        train(args, model, device, train_loader, optimizer, epoch)
        nvtx.range_pop()  # Train

        nvtx.range_push("Test")
        test(model, device, test_loader)
        nvtx.range_pop()  # Test

        scheduler.step()
        nvtx.range_pop()  # Epoch
        # Stop profiling at the end of the 2nd epoch
        if epoch == 2 and use_cuda:
            torch.cuda.cudart().cudaProfilerStop()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
Example #32
0
File: train.py  Project: tyhu/PyAI
def main():
    """Train joint image/text embedding branches on COCO with a triplet loss.

    NOTE(review): this is Python 2 / early-PyTorch code (`print` statements,
    the `file()` builtin, `Variable`, `loss.data[0]`); it will not run under
    Python 3 or a modern PyTorch without porting.
    """
    #batch_size = 500
    batch_size = 128
    # One embedding branch per modality; both are trained jointly below.
    img_net = ImgBranch2()
    text_net = TextBranch2()
    #img_net, text_net = torch.load('img_net.pt'), torch.load('text_net.pt')
    # Triplet loss with margin 0.1.
    tri_loss = TripletLoss(0.1)
    params = list(img_net.parameters())+list(text_net.parameters())
    opt = optim.SGD(params, lr=0.1, momentum=0.9, weight_decay=0.00005)
    # Decay the learning rate 10x every 10 epochs.
    scheduler = StepLR(opt, step_size=10, gamma=0.1)
    img_net.cuda()
    text_net.cuda()

    # Training pairs: precomputed VGG19 image features and HGLMM text
    # features, keyed by the ids listed one per line in train_ids.txt.
    idlst = [l.strip() for l in file('train_ids.txt')]
    img_dir = '/home/datasets/coco/raw/vgg19_feat/train/'
    text_dir = '/home/datasets/coco/raw/annotation_text/hglmm_pca_npy/train/'

    dataset = COCOImgTextFeatPairDataset(idlst,img_dir,text_dir)
    dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    ### Test set ###
    tiidlst = [l.strip() for l in file('test_ids.txt')]
    img_dir = '/home/datasets/coco/raw/vgg19_feat/val/'
    text_dir = '/home/datasets/coco/raw/annotation_text/hglmm_pca_npy/val/'
    img_feat_dataset = COCOImgFeatDataset(tiidlst, img_dir)
    text_feat_dataset = COCOTextFeatDataset(tiidlst,text_dir)

    ### train subset, used to monitor retrieval quality on seen data
    triidlst = [l.strip() for l in file('train_val_ids.txt')]
    img_dir = '/home/datasets/coco/raw/vgg19_feat/train/'
    text_dir = '/home/datasets/coco/raw/annotation_text/hglmm_pca_npy/train/'
    img_sub_dataset = COCOImgFeatDataset(triidlst, img_dir)
    text_sub_dataset = COCOTextFeatDataset(triidlst,text_dir)
    

    total_loss = 0
    for eidx in range(50):
        #total_loss = 0
        print 'epoch',eidx
        for i, batch in enumerate(dataloader):
            #anc_i, pos_t, neg_t, anc_t, pos_i, neg_i = hard_negative_sample(batch,img_net,text_net)
            # Random (not hard) negative sampling for the triplets.
            anc_i, pos_t, neg_t, anc_t, pos_i, neg_i = random_sample(batch)
            # Sub-batching hook; with sub_batch_num == 1 the whole batch is
            # processed at once.  NOTE(review): the size computation relies on
            # Python 2 integer division if sub_batch_num > 1.
            sub_batch_num = 1
            sub_batch_size = anc_i.shape[0]/sub_batch_num
            for j in range(sub_batch_num):
                start, end = j*sub_batch_size, (j+1)*sub_batch_size
                anc_i_sub, pos_t_sub, neg_t_sub, neg_i_sub = anc_i[start:end], pos_t[start:end], neg_t[start:end], neg_i[start:end]

                #anc_i_sub = img_net(Variable(torch.Tensor(anc_i_sub).cuda()))
                #pos_t_sub = text_net(Variable(torch.Tensor(pos_t_sub).cuda()))
                #neg_t_sub = text_net(Variable(torch.Tensor(neg_t_sub).cuda()))
                #neg_i_sub = img_net(Variable(torch.Tensor(neg_i_sub).cuda()))
                anc_i_sub = img_net(Variable(anc_i_sub.cuda()))
                pos_t_sub = text_net(Variable(pos_t_sub.cuda()))
                neg_t_sub = text_net(Variable(neg_t_sub.cuda()))
                neg_i_sub = img_net(Variable(neg_i_sub.cuda()))
                # Sampled pairs are aligned, so the text anchor is the positive
                # text embedding and the positive image is the image anchor.
                anc_t_sub = pos_t_sub
                pos_i_sub = anc_i_sub
                # Symmetric triplet terms: image->text and text->image.
                loss1 = tri_loss(anc_i_sub, pos_t_sub, neg_t_sub)
                loss2 = tri_loss(anc_t_sub, pos_i_sub, neg_i_sub)
                # Text->image term weighted 2x; rationale not recorded here.
                loss = loss1+2*loss2

                opt.zero_grad()
                loss.backward()
                opt.step()
                total_loss+=loss.data[0]

            # Every 200 batches: retrieval evaluation on the test set and on a
            # training subset, then reset the running loss.
            if i%200==0:
                print 'epoch',eidx,'batch', i
                test(tiidlst,img_feat_dataset, text_feat_dataset,img_net,text_net)
                print 'train sub'
                test(triidlst,img_sub_dataset, text_sub_dataset,img_net,text_net)
                print 'train loss:',total_loss
                total_loss = 0
            #break
        #break
        scheduler.step()

        ###### TEST ######
        # test(tiidlst,img_feat_dataset, text_feat_dataset,img_net,text_net)

        # Checkpoint the whole modules (not state_dicts) every epoch.
        torch.save(img_net,'img_net.pt')
        torch.save(text_net,'text_net.pt')
model = nn.DataParallel(model)


# Train all layers, with a 10x higher LR on the 'last_linear' head than on
# the remaining (presumably pretrained — confirm) backbone parameters.
other_parameters = [param for name, param in model.module.named_parameters() if 'last_linear' not in name]
optimizer = AdamW(
    [
        {"params": model.module.last_linear.parameters(), "lr": 1e-3},
        {"params": other_parameters},
    ],
    lr=1e-4, weight_decay=0.01)


best_loss_val = 100  # sentinel upper bound; real validation losses are expected below this
criterion = CosineMarginCrossEntropy().cuda()
# Decay both param-group LRs 10x every 18 epochs.
exp_lr_scheduler = StepLR(optimizer, step_size=18, gamma=0.1)
for epoch in range(num_epochs):
    # train for one epoch
    sample_weights = train(train_loader, model, criterion, optimizer, epoch, sample_weights, neptune_ctx)

    # evaluate on validation set
    acc1, acc5, loss_val = validate(val_loader, model, criterion)
    neptune_ctx.channel_send('val-acc1', acc1)
    neptune_ctx.channel_send('val-acc5', acc5)
    neptune_ctx.channel_send('val-loss', loss_val)
    # Log the LR actually applied during this epoch, read from the optimizer
    # (scheduler.get_lr() is deprecated and can misreport mid-schedule).
    neptune_ctx.channel_send('lr', float(optimizer.param_groups[0]['lr']))

    logger.info(f'Epoch: {epoch} Acc1: {acc1} Acc5: {acc5} Val-Loss: {loss_val}')

    # Step the LR schedule AFTER this epoch's optimizer updates; stepping it
    # before training is the deprecated pre-1.1.0 PyTorch ordering and shifts
    # the whole schedule forward by one epoch.
    exp_lr_scheduler.step()