Exemple #1
0
class ECGTrainer(object):
    """Train a 55-output DenseNet on ECG data and keep the best checkpoint.

    The model is optimised with AMSBound against a ``ComboLoss`` made of
    ``mlsml``/``f1``/``focal`` terms (weights 1/1/3).  After every epoch the
    validation F1 (sigmoid outputs thresholded at 0.5) drives a
    ReduceLROnPlateau scheduler and decides whether ``model.pth`` is
    overwritten.
    """

    def __init__(self, block_config='small', num_threads=2):
        """Build model, loss, optimizer and LR scheduler.

        Args:
            block_config: ``'small'`` selects DenseNet block sizes
                ``(3, 6, 12, 8)``; any other value selects ``(6, 12, 24, 16)``.
            num_threads: passed to ``torch.set_num_threads`` and to
                ``ECGLoader``.
        """
        torch.set_num_threads(num_threads)
        self.n_epochs = 60
        self.batch_size = 128
        self.scheduler = None
        self.num_threads = num_threads
        self.cuda = torch.cuda.is_available()

        if block_config == 'small':
            self.block_config = (3, 6, 12, 8)
        else:
            self.block_config = (6, 12, 24, 16)

        self.__build_model()
        self.__build_criterion()
        self.__build_optimizer()
        self.__build_scheduler()

    def __build_model(self):
        """Create the DenseNet, moved to the GPU when CUDA is available."""
        self.model = DenseNet(
            num_classes=55, block_config=self.block_config
        )
        if self.cuda:
            self.model.cuda()

    def __build_criterion(self):
        """Create the weighted multi-term training loss."""
        self.criterion = ComboLoss(
            losses=['mlsml', 'f1', 'focal'], weights=[1, 1, 3]
        )

    def __build_optimizer(self):
        """Create an AMSBound (AdaBound with ``amsbound=True``) optimizer."""
        self.optimizer = AdaBound(
            params=self.model.parameters(), lr=1e-3, weight_decay=0.0,
            amsbound=True,
        )

    def __build_scheduler(self):
        """Scale the LR by 0.333 when validation F1 has not improved for 5 epochs."""
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, 'max', factor=0.333, patience=5,
            verbose=True, min_lr=1e-5)

    @staticmethod
    def _binarize(probs, threshold=0.5):
        """Return a copy of ``probs`` with entries > ``threshold`` set to 1, others 0.

        The result keeps the dtype of ``probs``.
        """
        return (probs > threshold).astype(probs.dtype)

    def _run_phase(self, phase, loader, ep_message):
        """Run one pass over ``loader`` in ``'train'`` or ``'valid'`` mode.

        In ``'train'`` mode the optimizer is stepped after every batch.
        Progress is printed per batch and a summary line at the end.

        Returns:
            ``(avg_loss, avg_f1, labels, preds)``; ``labels``/``preds`` are the
            numpy arrays of all batches concatenated along axis 0, ``preds``
            being sigmoid probabilities.
        """
        is_train = phase == 'train'
        self.model.train(is_train)

        losses, preds, labels = [], [], []
        batch_num = len(loader)
        # Autograd is only needed while training; disabling it for validation
        # avoids building graphs that are never back-propagated.
        with torch.set_grad_enabled(is_train):
            for ith_batch, data in enumerate(loader):
                ecg, label = [d.cuda() for d in data] if self.cuda else data

                pred = self.model(ecg)
                loss = self.criterion(pred, label)
                if is_train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                pred = torch.sigmoid(pred).detach().cpu().numpy()
                label = label.detach().cpu().numpy()
                f1 = f1_score(label.flatten(), self._binarize(pred).flatten())

                loss_value = loss.item()
                losses.append(loss_value)
                preds.append(pred)
                labels.append(label)

                sr_message = ep_message + '[STEP {:0=3d}/{:0=3d}]-[Loss: {:.6f} F1: {:.6f}]'
                print(sr_message.format(ith_batch + 1, batch_num, loss_value, f1), end='\r')

        preds = np.concatenate(preds, axis=0)
        labels = np.concatenate(labels, axis=0)

        avg_loss = np.mean(losses)
        avg_f1 = f1_score(labels.flatten(), self._binarize(preds).flatten())
        er_message = '\n\033[94m' + ep_message + '-----[Loss: {:.6f} F1: {:.6f}]' + '\033[0m'
        print(er_message.format(avg_loss, avg_f1))
        return avg_loss, avg_f1, labels, preds

    def run(self, trainset, validset, model_dir):
        """Train for ``self.n_epochs`` epochs, validating after each one.

        Saves ``model_dir/model.pth`` whenever the validation F1 improves.
        At the end it saves ``model_dir/threshold.npy``, the thresholds that
        ``best_f1_score`` finds on the best epoch's validation predictions.

        Args:
            trainset: training data, wrapped by ``ECGLoader`` with batch size
                ``self.batch_size``.
            validset: validation data, wrapped by ``ECGLoader`` with batch
                size 64.
            model_dir: directory for the outputs (it is not created here).
        """
        print('=' * 100 + '\n' + 'TRAINING MODEL\n' + '-' * 100 + '\n')
        model_path = os.path.join(model_dir, 'model.pth')
        thresh_path = os.path.join(model_dir, 'threshold.npy')

        dataloader = {
            'train': ECGLoader(trainset, self.batch_size, True, self.num_threads).build(),
            'valid': ECGLoader(validset, 64, False, self.num_threads).build()
        }

        best_metric, best_preds, best_loss_metrics = None, None, None
        for epoch in range(self.n_epochs):
            e_message = '[EPOCH {:0=3d}/{:0=3d}]'.format(epoch + 1, self.n_epochs)

            for phase in ['train', 'valid']:
                ep_message = e_message + '[' + phase.upper() + ']'
                avg_loss, avg_f1, labels, preds = self._run_phase(
                    phase, dataloader[phase], ep_message)

                if phase == 'valid':
                    if self.scheduler is not None:
                        self.scheduler.step(avg_f1)
                    if best_metric is None or best_metric < avg_f1:
                        best_metric = avg_f1
                        best_preds = [labels, preds]
                        best_loss_metrics = [epoch + 1, avg_loss, avg_f1]
                        torch.save(self.model.state_dict(), model_path)
                        print('[Best validation metric, model: {}]'.format(model_path))
                    print()

        best_f1, best_th = best_f1_score(*best_preds)
        np.save(thresh_path, np.array(best_th))
        print('[Searched Best F1: {:.6f}]\n'.format(best_f1))
        res_message = '[VALIDATION PERFORMANCE: BEST F1]' + '\n' \
            + '[EPOCH:{} LOSS:{:.6f} F1:{:.6f} BEST F1:{:.6f}]\n'.format(
                best_loss_metrics[0], best_loss_metrics[1],
                best_loss_metrics[2], best_f1) \
            + '[BEST THRESHOLD:\n{}]\n'.format(best_th) \
            + '=' * 100 + '\n'
        print(res_message)
Exemple #2
0
        # Epoch loop (fragment: the enclosing function and opt, model,
        # metric_fc, criterion, optimizer, scheduler, train_loader,
        # callback_manager and device are defined outside this view).
        for epoch in range(opt.max_epoch):
            model.train()
            callback_manager.on_epoch_start(epoch)

            for i, data in enumerate(train_loader):
                callback_manager.on_batch_start(n_batch=i)
                data_input, label = data
                data_input = data_input.to(device)
                label = label.to(device).long()
                feature = model(data_input)
                # metric_fc receives the labels as well as the features.
                output = metric_fc(feature, label)
                loss = criterion(output, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                metric = calculate_metrics(output, label)
                metric[Config.loss] = loss_value
                metric['lr'] = get_lr(optimizer)
                callback_manager.on_batch_end(loss=loss_value,
                                              n_batch=i,
                                              train_metric=metric)
                if Config.is_debug:
                    break

            # Step the epoch-level LR scheduler after this epoch's optimizer
            # updates (PyTorch >= 1.1 ordering; stepping before them skips the
            # first value of the schedule).
            scheduler.step()

            # range() stops at max_epoch - 1, so that is the last epoch; the
            # previous `epoch == opt.max_epoch` test could never be true.
            if epoch % opt.save_interval == 0 or epoch == opt.max_epoch - 1:
                save_model(model, opt.checkpoints_path, opt.backbone, epoch)
                save_model(metric_fc, opt.checkpoints_path, opt.metric, epoch)
Exemple #3
0
class TrainNetwork(object):
    """The main train network.

    Trains a fixed architecture (``geno_types.<args.arch>``) on cifar10,
    cifar100, tiny-imagenet, imagenet or food101.  Construction prepares
    logging, dataset-specific hyper-parameters, seeds and CUDA device, data
    loaders, model, loss, optimizer and LR scheduler.  It can optionally
    resume from ``args.resume``.  ``run()`` performs the training.
    """

    def __init__(self, args):
        """Set everything up from ``args``, which is mutated in place.

        Exits the process with status 1 when no CUDA device is available.
        """
        super(TrainNetwork, self).__init__()
        self.args = args
        self.dur_time = 0
        self.logger = self._init_log()

        if not torch.cuda.is_available():
            self.logger.info('no gpu device available')
            sys.exit(1)

        self._init_hyperparam()
        self._init_random_and_device()
        self._init_model()

    def _init_hyperparam(self):
        """Override dataset-specific entries of ``self.args``.

        cifar10 has no branch and keeps the command-line values.
        """
        if 'cifar100' == self.args.train_dataset:
            # cifar10:  6000 images per class, 10 classes, 50000 training images and 10000 test images
            # cifar100: 600 images per class, 100 classes, 500 training images and 100 testing images per class
            self.args.num_classes = 100
            self.args.layers = 20
            self.args.data = '/train_tiny_data/train_data/cifar100'
        elif 'imagenet' == self.args.train_dataset:
            self.args.data = '/train_data/imagenet'
            self.args.num_classes = 1000
            self.args.weight_decay = 3e-5
            self.args.report_freq = 100
            self.args.init_channels = 50
            self.args.drop_path_prob = 0
        elif 'tiny-imagenet' == self.args.train_dataset:
            self.args.data = '/train_tiny_data/train_data/tiny-imagenet'
            self.args.num_classes = 200
        elif 'food101' == self.args.train_dataset:
            self.args.data = '/train_tiny_data/train_data/food-101'
            self.args.num_classes = 101
            self.args.init_channels = 48

    def _init_log(self):
        """Create the experiment directory and return a logger.

        ``args.save`` is rewritten to
        ``../logs/eval/<arch>/<train_dataset>/eval-<save>-<YYYYmmdd-HHMM>``.
        The returned logger writes to ``<args.save>/log.txt``, and
        ``logging.basicConfig`` sends records to stdout.
        """
        self.args.save = '../logs/eval/' + self.args.arch + '/' + self.args.train_dataset + '/eval-{}-{}'.format(self.args.save, time.strftime('%Y%m%d-%H%M'))
        dutils.create_exp_dir(self.args.save, scripts_to_save=None)

        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(self.args.save, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logger = logging.getLogger('Architecture Training')
        logger.addHandler(fh)
        return logger

    def _init_random_and_device(self):
        """Seed numpy/torch/CUDA, enable cudnn and choose the CUDA device.

        The device is ``cuda:0`` with ``args.multi_gpus``.  Otherwise it is
        the ``max_free_gpu_id`` reported by ``dutils.get_gpus_memory_info``.
        """
        np.random.seed(self.args.seed)
        cudnn.benchmark = True
        torch.manual_seed(self.args.seed)
        cudnn.enabled = True
        torch.cuda.manual_seed(self.args.seed)
        max_free_gpu_id, gpus_info = dutils.get_gpus_memory_info()
        self.device_id = max_free_gpu_id
        self.gpus_info = gpus_info
        self.device = torch.device('cuda:{}'.format(0 if self.args.multi_gpus else self.device_id))

    def _bare_model(self):
        """Return the trained network itself, unwrapping ``nn.DataParallel``."""
        model = self.model
        return model.module if isinstance(model, nn.DataParallel) else model

    def _init_scheduler(self):
        """Cosine annealing over ``args.epochs`` for CIFAR, StepLR otherwise."""
        if 'cifar' in self.args.train_dataset:
            return torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, float(self.args.epochs))
        return torch.optim.lr_scheduler.StepLR(self.optimizer, self.args.decay_period,
                                               gamma=self.args.gamma)

    def _init_model(self):
        """Build data loaders, model, criterion, optimizer and LR scheduler.

        If ``args.resume`` names an existing file, the model, optimizer and
        scheduler states are restored from it.  The epoch counter, best
        accuracy, elapsed time and drop-path probability are restored too.
        """
        self.train_queue, self.valid_queue = self._load_dataset_queue()

        # getattr rather than eval(): only an attribute lookup is needed.
        genotype = getattr(geno_types, self.args.arch)
        reduce_level = 0  # every dataset currently uses reduce level 0
        model = EvalNetwork(self.args.init_channels, self.args.num_classes, 0,
                            self.args.layers, self.args.auxiliary, genotype, reduce_level)

        # Try move model to multi gpus
        if torch.cuda.device_count() > 1 and self.args.multi_gpus:
            self.logger.info('use: %d gpus', torch.cuda.device_count())
            model = nn.DataParallel(model)
        else:
            self.logger.info('gpu device = %d' % self.device_id)
            torch.cuda.set_device(self.device_id)
        self.model = model.to(self.device)

        self.logger.info('param size = %fM', dutils.calc_parameters_count(model))

        criterion = nn.CrossEntropyLoss()
        if self.args.num_classes >= 50:
            criterion = CrossEntropyLabelSmooth(self.args.num_classes, self.args.label_smooth)
        self.criterion = criterion.to(self.device)

        if self.args.opt == 'adam':
            # NOTE(review): the 'adam' option builds Adamax; confirm intended.
            self.optimizer = torch.optim.Adamax(
                model.parameters(),
                self.args.learning_rate,
                weight_decay=self.args.weight_decay
            )
        elif self.args.opt == 'adabound':
            self.optimizer = AdaBound(model.parameters(),
                                      self.args.learning_rate,
                                      weight_decay=self.args.weight_decay)
        else:
            self.optimizer = torch.optim.SGD(
                model.parameters(),
                self.args.learning_rate,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay
            )

        # Build the scheduler before restoring the optimizer state.  Then the
        # checkpointed learning rate stays in effect, whatever the scheduler
        # constructor does to the optimizer's param groups.
        self.scheduler = self._init_scheduler()

        self.best_acc_top1 = 0
        # optionally resume from a checkpoint (loaded once)
        if self.args.resume:
            if os.path.isfile(self.args.resume):
                print("=> loading checkpoint {}".format(self.args.resume))
                checkpoint = torch.load(self.args.resume, map_location=self.device)
                self.dur_time = checkpoint['dur_time']
                self.args.start_epoch = checkpoint['epoch']
                self.best_acc_top1 = checkpoint['best_acc_top1']
                self.args.drop_path_prob = checkpoint['drop_path_prob']
                self.model.load_state_dict(checkpoint['state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.scheduler.load_state_dict(checkpoint['scheduler'])
                print("=> loaded checkpoint '{}' (epoch {})".format(self.args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(self.args.resume))

    def _make_queue(self, dataset, shuffle):
        """Wrap ``dataset`` in a pinned-memory DataLoader with 4 workers."""
        return torch.utils.data.DataLoader(
            dataset, batch_size=self.args.batch_size, shuffle=shuffle,
            pin_memory=True, num_workers=4
        )

    def _load_dataset_queue(self):
        """Return ``(train_queue, valid_queue)`` DataLoaders for ``args.train_dataset``.

        Only the training loader shuffles; evaluation metrics do not depend on
        order.

        Raises:
            ValueError: if ``args.train_dataset`` is not supported.
        """
        dataset_name = self.args.train_dataset
        if 'cifar' in dataset_name:
            train_transform, valid_transform = dutils.data_transforms_cifar(self.args)
            cifar_cls = dset.CIFAR10 if 'cifar10' == dataset_name else dset.CIFAR100
            train_data = cifar_cls(root=self.args.data, train=True, download=True, transform=train_transform)
            valid_data = cifar_cls(root=self.args.data, train=False, download=True, transform=valid_transform)
        elif 'tiny-imagenet' == dataset_name:
            train_transform, valid_transform = dutils.data_transforms_tiny_imagenet()
            train_data = dartsdset.TinyImageNet200(self.args.data, train=True, download=True, transform=train_transform)
            valid_data = dartsdset.TinyImageNet200(self.args.data, train=False, download=True, transform=valid_transform)
        elif dataset_name in ('imagenet', 'food101'):
            if 'imagenet' == dataset_name:
                train_transform, valid_transform = dutils.data_transforms_imagenet()
            else:
                train_transform, valid_transform = dutils.data_transforms_food101()
            train_data = dset.ImageFolder(os.path.join(self.args.data, 'train'), train_transform)
            valid_data = dset.ImageFolder(os.path.join(self.args.data, 'val'), valid_transform)
        else:
            raise ValueError('unsupported train_dataset: {!r}'.format(dataset_name))

        return self._make_queue(train_data, True), self._make_queue(valid_data, False)

    def run(self):
        """Train from ``args.start_epoch`` up to ``args.epochs``.

        After every epoch the model is validated and the state is passed to
        ``dutils.save_checkpoint``; ``is_best`` is set when top-1 accuracy
        improved.
        """
        self.logger.info('args = %s', self.args)
        run_start = time.time()
        for epoch in range(self.args.start_epoch, self.args.epochs):
            # Learning rate in effect for this epoch; the scheduler is stepped
            # after the epoch's optimizer updates (PyTorch >= 1.1 ordering).
            self.logger.info('epoch % d / %d  lr %e', epoch, self.args.epochs,
                             self.optimizer.param_groups[0]['lr'])

            # Set drop-path on the network itself: an attribute set on an
            # nn.DataParallel wrapper is not seen by the wrapped module.
            net = self._bare_model()
            if self.args.no_dropout:
                net._drop_path_prob = 0
            else:
                net._drop_path_prob = self.args.drop_path_prob * epoch / self.args.epochs
                self.logger.info('drop_path_prob %e', net._drop_path_prob)

            train_acc, train_obj = self.train()
            self.logger.info('train loss %e, train acc %f', train_obj, train_acc)

            valid_acc_top1, valid_acc_top5, valid_obj = self.infer()
            self.logger.info('valid loss %e, top1 valid acc %f top5 valid acc %f',
                             valid_obj, valid_acc_top1, valid_acc_top5)
            self.logger.info('best valid acc %f', self.best_acc_top1)

            self.scheduler.step()

            is_best = valid_acc_top1 > self.best_acc_top1
            if is_best:
                self.best_acc_top1 = valid_acc_top1

            dutils.save_checkpoint({
                'epoch': epoch + 1,
                'dur_time': self.dur_time + time.time() - run_start,
                'state_dict': self.model.state_dict(),
                'drop_path_prob': self.args.drop_path_prob,
                'best_acc_top1': self.best_acc_top1,
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict()
            }, is_best, self.args.save)
        self.logger.info('train epoches %d, best_acc_top1 %f, dur_time %s',
                         self.args.epochs, self.best_acc_top1, dutils.calc_time(self.dur_time + time.time() - run_start))

    def train(self):
        """Run one training epoch; return ``(top1_avg, loss_avg)``.

        When ``args.auxiliary`` is set, the auxiliary-head loss is added with
        weight ``args.auxiliary_weight``.  Gradients are clipped to
        ``args.grad_clip``.
        """
        objs = dutils.AverageMeter()
        top1 = dutils.AverageMeter()
        top5 = dutils.AverageMeter()

        self.model.train()

        for step, (inputs, target) in enumerate(self.train_queue):
            inputs = inputs.cuda(self.device, non_blocking=True)
            target = target.cuda(self.device, non_blocking=True)

            self.optimizer.zero_grad()
            logits, logits_aux = self.model(inputs)
            loss = self.criterion(logits, target)
            if self.args.auxiliary:
                loss_aux = self.criterion(logits_aux, target)
                loss = loss + self.args.auxiliary_weight * loss_aux
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
            self.optimizer.step()

            prec1, prec5 = dutils.accuracy(logits, target, topk=(1, 5))
            n = inputs.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.args.report_freq == 0:
                self.logger.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

        return top1.avg, objs.avg

    def infer(self):
        """Evaluate on the validation set; return ``(top1_avg, top5_avg, loss_avg)``."""
        objs = dutils.AverageMeter()
        top1 = dutils.AverageMeter()
        top5 = dutils.AverageMeter()
        self.model.eval()
        with torch.no_grad():
            for step, (inputs, target) in enumerate(self.valid_queue):
                inputs = inputs.cuda(self.device, non_blocking=True)
                target = target.cuda(self.device, non_blocking=True)

                logits, _ = self.model(inputs)
                loss = self.criterion(logits, target)

                prec1, prec5 = dutils.accuracy(logits, target, topk=(1, 5))
                n = inputs.size(0)
                objs.update(loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)

                if step % self.args.report_freq == 0:
                    self.logger.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
        return top1.avg, top5.avg, objs.avg
Exemple #4
0
class OriginalModel(BaseModel):
    """Shared encoder, per-domain WaveNet decoders and a domain classifier.

    ``netE`` encodes audio from either domain A or B.  ``netC`` is trained
    with an MSE loss to predict the source domain of an encoding (target 0
    for A, 1 for B).  The decoders ``netD_A`` / ``netD_B`` are trained with
    cross-entropy to predict the quantised original audio (``get_indices``).
    The encoder and decoders minimise the decoder loss minus
    ``opt.dc_lambda`` times netC's loss, i.e. they push the encodings to
    fool netC.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new model-specific options and rewrite default values for existing options.

        Parameters:
            parser -- the option parser
            is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        preprocess = 'normalize,mulaw,cqt'
        parser.set_defaults(preprocess=preprocess, flatten=True)
        parser.add_argument('--wavenet_layers', type=int, default=40, help='wavenet layers')
        parser.add_argument('--wavenet_blocks', type=int, default=10, help='wavenet blocks')
        parser.add_argument('--width', type=int, default=128, help='width')
        parser.add_argument('--dc_lambda', type=float, default=0.01, help='dc lambda')
        parser.add_argument('--tanh', action='store_true', help='tanh')
        parser.add_argument('--sigmoid', action='store_true', help='sigmoid')
        return parser

    def __init__(self, opt):
        """Build the networks.

        In training mode the losses and AdaBound optimizers are also built.
        Otherwise saved weights are loaded and NVWaveNet inference engines
        are exported from the decoders.
        """
        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel
        self.loss_names = ['C_A_right', 'C_B_right', 'C_A_wrong', 'C_B_wrong', 'D_A', 'D_B']
        if opt.isTrain:
            self.output_names = []  # ['aug_A', 'aug_B', 'rec_A', 'rec_B']
        else:
            self.output_names = ['real_A', 'real_B', 'fake_B', 'fake_A']
        # NOTE(review): each name appears twice in this list; confirm intended.
        self.params_names = ['params_A', 'params_B'] * 2
        self.model_names = ['E', 'C', 'D_A', 'D_B']

        # Shared encoder and domain classifier, built for self.devices[0].
        self.netE = getGenerator(self.devices[0], opt)
        self.netC = getDiscriminator(opt, self.devices[0])

        # One WaveNet decoder per domain, placed on the last device.
        self.netD_A = WaveNet(opt.mu+1, opt.wavenet_layers, opt.wavenet_blocks,
                              opt.width, 256, 256,
                              opt.tensor_height, 1, 1).to(self.devices[-1])  # opt.pool_length, opt.pool_length
        self.netD_B = WaveNet(opt.mu+1, opt.wavenet_layers, opt.wavenet_blocks,
                              opt.width, 256, 256,
                              opt.tensor_height, 1, 1).to(self.devices[-1])  # opt.pool_length, opt.pool_length
        self.softmax = nn.LogSoftmax(dim=1)  # (1, 256, audio_len) -> pick 256

        if self.isTrain:
            # Domain labels for the classifier: 0 for A, 1 for B.
            self.A_target = torch.zeros(opt.batch_size).to(self.devices[0])
            self.B_target = torch.ones(opt.batch_size).to(self.devices[0])
            self.criterionDC = nn.MSELoss(reduction='mean')
            self.criterionDecode = nn.CrossEntropyLoss(reduction='mean')
            self.optimizer_C = AdaBound(self.netC.parameters(), lr=opt.lr, final_lr=0.1)
            self.optimizer_D = AdaBound(itertools.chain(self.netE.parameters(), self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, final_lr=0.1)
            self.optimizers = [self.optimizer_C, self.optimizer_D]
        else:
            self.preprocesses = []
            # TODO change structure of test.py and setup() instead
            load_suffix = str(opt.load_iter) if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix)
            self.netC.eval()
            self.netD_A.eval()
            self.netD_B.eval()

            self.infer_A = NVWaveNet(**(self.netD_A.export_weights()))
            self.infer_B = NVWaveNet(**(self.netD_B.export_weights()))

    def set_input(self, input):
        """Store a batch ``((A, params_A), (B, params_B))``.

        ``A``/``B`` become ``aug_A``/``aug_B``; ``params_X['original']``
        becomes ``real_X``; the params dicts are decollated.
        """
        A, params_A = input[0]
        B, params_B = input[1]

        self.real_A = params_A['original'].to(self.devices[0])
        self.real_B = params_B['original'].to(self.devices[0])
        self.aug_A = A.to(self.devices[0])
        self.aug_B = B.to(self.devices[0])

        self.params_A = self.decollate_params(params_A)
        self.params_B = self.decollate_params(params_B)

    def get_indices(self, y):
        """Map values in [-1, 1] to integer indices in [0, opt.mu] (truncating)."""
        y = (y + 1.) * .5 * self.opt.mu
        return y.long()

    def inv_indices(self, y):
        """Map indices in [0, opt.mu] back to floats in [-1, 1]."""
        return y.float() / self.opt.mu * 2. - 1.

    def _encode_resized(self, audio):
        """Encode ``audio`` with netE, resize to ``opt.audio_length`` and move
        the result to the decoder device ``self.devices[-1]``."""
        encoded = self.netE(audio)  # Input range: (-1, 1) Output: R^64
        encoded = nn.functional.interpolate(encoded, size=self.opt.audio_length)
        return encoded.to(self.devices[-1])

    def train(self):
        """Run one optimisation step.

        First netC is updated, then the encoder and both decoders are
        updated.
        """
        # --- Step 1: update the domain classifier netC. ---
        # optimizer_C only updates netC.  Encoder gradients from this step
        # would be cleared by optimizer_D.zero_grad() before use, so the
        # encoder runs without autograd here.
        self.optimizer_C.zero_grad()
        with torch.no_grad():
            encoded_A = self._encode_resized(self.aug_A)
        pred_C_A = self.netC(encoded_A)
        self.loss_C_A_right = self.opt.dc_lambda * self.criterionDC(pred_C_A, self.A_target)
        self.loss_C_A_right.backward()

        with torch.no_grad():
            encoded_B = self._encode_resized(self.aug_B)
        pred_C_B = self.netC(encoded_B)
        self.loss_C_B_right = self.opt.dc_lambda * self.criterionDC(pred_C_B, self.B_target)
        self.loss_C_B_right.backward()
        self.optimizer_C.step()

        # --- Step 2: update encoder + decoders. They minimise the decoder
        # loss and *maximise* netC's loss (hence the minus sign). ---
        self.optimizer_D.zero_grad()
        encoded_A = self._encode_resized(self.aug_A)
        pred_C_A = self.netC(encoded_A)
        self.loss_C_A_wrong = self.criterionDC(pred_C_A, self.A_target)
        real_A = self.get_indices(self.real_A).to(self.devices[-1])
        pred_D_A = self.netD_A((encoded_A, real_A))
        self.loss_D_A = self.criterionDecode(pred_D_A, real_A)
        loss = self.loss_D_A - self.opt.dc_lambda * self.loss_C_A_wrong
        loss.backward()

        encoded_B = self._encode_resized(self.aug_B)
        pred_C_B = self.netC(encoded_B)
        self.loss_C_B_wrong = self.criterionDC(pred_C_B, self.B_target)
        real_B = self.get_indices(self.real_B).to(self.devices[-1])
        pred_D_B = self.netD_B((encoded_B, real_B))
        self.loss_D_B = self.criterionDecode(pred_D_B, real_B)
        loss = self.loss_D_B - self.opt.dc_lambda * self.loss_C_B_wrong
        loss.backward()
        self.optimizer_D.step()

    def test(self):
        """Decode B's encoding with decoder A, and A's encoding with decoder B.

        The NVWaveNet engines do the decoding.  The results are stored as
        ``fake_B`` and ``fake_A``, mapped back to [-1, 1].
        """
        with torch.no_grad():
            encoded_A = self.netE(self.aug_A)
            encoded_B = self.netE(self.aug_B)
            self.fake_B = self.infer_A.infer(self.netD_A.get_cond_input(encoded_B), Impl.AUTO)
            self.fake_A = self.infer_B.infer(self.netD_B.get_cond_input(encoded_A), Impl.AUTO)
            self.fake_B = self.inv_indices(self.fake_B)
            self.fake_A = self.inv_indices(self.fake_A)
Exemple #5
0
def main(args):
    """Train BiDAF on SQuAD and periodically evaluate and checkpoint on dev.

    Every ``args.eval_steps`` training examples, the EMA weights are swapped
    in and the dev set is evaluated.  The result is offered to
    ``util.CheckpointSaver`` and logged to the console and TensorBoard.
    Training loops until ``epoch == args.num_epochs``, so a negative value
    never stops.

    Args:
        args: parsed command-line namespace.  ``save_dir``, ``gpu_ids`` and
            ``batch_size`` are overwritten.
    """
    # Set up logging and devices
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=True)
    log = util.get_logger(args.save_dir, args.name)
    tbx = SummaryWriter(args.save_dir)
    # Close the writer on every exit path (normal end, exception, Ctrl-C) so
    # buffered TensorBoard events are flushed to disk.
    try:
        device, args.gpu_ids = util.get_available_devices()
        log.info('Args: {}'.format(dumps(vars(args), indent=4, sort_keys=True)))
        # Scale the batch size with the number of GPUs given to DataParallel.
        args.batch_size *= max(1, len(args.gpu_ids))

        # Set random seed
        log.info('Using random seed {}...'.format(args.seed))
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

        # Get embeddings
        log.info('Loading embeddings...')
        word_vectors = util.torch_from_json(args.word_emb_file)
        char_vectors = util.torch_from_json(args.char_emb_file)

        # Get model
        log.info('Building model...')
        model = BiDAF(word_vectors=word_vectors,
                      char_vectors=char_vectors,
                      hidden_size=args.hidden_size,
                      drop_prob=args.drop_prob)
        model = nn.DataParallel(model, args.gpu_ids)
        if args.load_path:
            log.info('Loading checkpoint from {}...'.format(args.load_path))
            model, step = util.load_model(model, args.load_path, args.gpu_ids)
        else:
            step = 0
        model = model.to(device)
        model.train()
        ema = util.EMA(model, args.ema_decay)

        # Get saver
        saver = util.CheckpointSaver(args.save_dir,
                                     max_checkpoints=args.max_checkpoints,
                                     metric_name=args.metric_name,
                                     maximize_metric=args.maximize_metric,
                                     log=log)

        # Get optimizer: AdaBound with its library defaults.  args.lr and
        # args.l2_wd are not used, and no LR scheduler is attached.
        optimizer = AdaBound(model.parameters())

        # Get data loader
        log.info('Building dataset...')
        train_dataset = SQuAD(args.train_record_file, args.use_squad_v2)
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       collate_fn=collate_fn)
        dev_dataset = SQuAD(args.dev_record_file, args.use_squad_v2)
        dev_loader = data.DataLoader(dev_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=args.num_workers,
                                     collate_fn=collate_fn)

        # Train.  `step` counts training examples seen, not batches.
        log.info('Training...')
        steps_till_eval = args.eval_steps
        epoch = step // len(train_dataset)
        while epoch != args.num_epochs:
            epoch += 1
            log.info('Starting epoch {}...'.format(epoch))
            with torch.enable_grad(), \
                    tqdm(total=len(train_loader.dataset)) as progress_bar:
                for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids, cwf in train_loader:
                    # Setup for forward
                    cw_idxs = cw_idxs.to(device)
                    qw_idxs = qw_idxs.to(device)
                    cc_idxs = cc_idxs.to(device)
                    qc_idxs = qc_idxs.to(device)
                    batch_size = cw_idxs.size(0)
                    cwf = cwf.to(device)
                    optimizer.zero_grad()

                    # Forward
                    log_p1, log_p2 = model(cc_idxs, qc_idxs, cw_idxs, qw_idxs, cwf)
                    y1, y2 = y1.to(device), y2.to(device)
                    loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
                    loss_val = loss.item()

                    # Backward
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             args.max_grad_norm)
                    optimizer.step()
                    ema(model, step // batch_size)

                    # Log info
                    step += batch_size
                    progress_bar.update(batch_size)
                    progress_bar.set_postfix(epoch=epoch, NLL=loss_val)
                    tbx.add_scalar('train/NLL', loss_val, step)
                    tbx.add_scalar('train/LR', optimizer.param_groups[0]['lr'],
                                   step)

                    steps_till_eval -= batch_size
                    if steps_till_eval <= 0:
                        steps_till_eval = args.eval_steps

                        # Evaluate with EMA weights and save checkpoint, then
                        # restore the raw training weights.
                        log.info('Evaluating at step {}...'.format(step))
                        ema.assign(model)
                        results, pred_dict = evaluate(model, dev_loader, device,
                                                      args.dev_eval_file,
                                                      args.max_ans_len,
                                                      args.use_squad_v2)
                        saver.save(step, model, results[args.metric_name], device)
                        ema.resume(model)

                        # Log to console
                        results_str = ', '.join('{}: {:05.2f}'.format(k, v)
                                                for k, v in results.items())
                        log.info('Dev {}'.format(results_str))

                        # Log to TensorBoard
                        log.info('Visualizing in TensorBoard...')
                        for k, v in results.items():
                            tbx.add_scalar('dev/{}'.format(k), v, step)
                        util.visualize(tbx,
                                       pred_dict=pred_dict,
                                       eval_path=args.dev_eval_file,
                                       step=step,
                                       split='dev',
                                       num_visuals=args.num_visuals)
    finally:
        tbx.close()
Exemple #6
0
class NNProcess:
    """Train and evaluate a ResNet-based binary classifier on stock-graph images.

    The constructor builds train/test ``DataLoader``s from
    ``dataset.stockGraphGenerator``, creates (or loads) the model, a
    ``BCELoss`` criterion and an ``AdaBound`` optimizer. ``train`` runs the
    epoch loop and records per-epoch history retrievable via
    ``fetch_trainHist()``.
    """

    def __init__(self,
                 import_trained=(False, ''),
                 model_pretrained=(True, True),
                 save_model=True,
                 resnet_depth=50,
                 lr=1e-3,
                 momentum=0.09,
                 nesterov=False,
                 threshold=0.5,
                 epochs=50,
                 batch_size=64,
                 train_val_split=0.7,
                 data_interval='1min',
                 predict_period=1,
                 mins_interval=30,
                 start_date='2020-08-24',
                 end_date='2020-08-29'):
        '''
        import_trained = (whether to import a trained model file, and if so its filename WITHOUT the '.pth' extension)
        model_pretrained = (whether to import a pretrained model, whether to only train the linear layers); passed to ResNetClassifier as `pretrained`
        save_model = whether to save the whole model to './resnet_market_predictor.pth' when training finishes
        resnet_depth = depth of the residual network
        lr = learning rate for the AdaBound optimizer
        momentum = currently unused (kept for backward compatibility; the optimizer is AdaBound)
        nesterov = currently unused (kept for backward compatibility)
        threshold = investment threshold; stored on the instance but not used by the methods of this class
        epochs = training hyperparameter: the number of passes over the training set
        batch_size = training hyperparameter: the number of samples per batch
        train_val_split = training hyperparameter: passed to the data generator as `split`
        data_interval = the time interval between each datapoint
        predict_period = the amount of time periods to predict forwards
        mins_interval = the amount of minutes to show in the graph
        start_date = the first date to get data
        end_date = the last date to get data
        '''

        self.__import_trained = import_trained
        self.__model_pretrained = model_pretrained
        self.__saveModel = save_model
        self.__resnet_depth = resnet_depth
        self.__threshold = threshold
        self.__epochs = epochs
        self.__batch_size = batch_size
        data = dataset.stockGraphGenerator(split=train_val_split,
                                           interval=data_interval,
                                           predict_period=predict_period,
                                           mins_interval=mins_interval,
                                           start_date=start_date,
                                           end_date=end_date,
                                           stride=15)
        # NOTE(review): training batches are not shuffled; confirm this is
        # intended for this time-series data.
        self.__train_set = torch.utils.data.DataLoader(
            data.train_data, batch_size=self.__batch_size, shuffle=False)
        self.__test_set = torch.utils.data.DataLoader(
            data.test_data, batch_size=self.__batch_size, shuffle=False)
        self.__model = self.__loadmodelInstance(
        ) if self.__import_trained[0] else self.__createmodelInstance()
        self.__criterion = nn.BCELoss()
        self.__optim = AdaBound(self.__model.parameters(),
                                amsbound=True,
                                lr=lr,
                                final_lr=0.1)
        # [train losses, test losses, train accuracies, test accuracies], one entry per epoch.
        self.__trainHist = [[], [], [], []]

    def __loadmodelInstance(self):
        """Load a whole pickled model from '<import_trained[1]>.pth', on CUDA if available."""
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only host;
        # it is moved to CUDA below when a GPU is present.
        model = torch.load(self.__import_trained[1] + '.pth', map_location='cpu')
        return model.cuda() if torch.cuda.is_available() else model

    def __createmodelInstance(self):
        """Build a fresh ResNetClassifier, on CUDA if available."""
        model = ResNetClassifier(pretrained=self.__model_pretrained,
                                 resnet_depth=self.__resnet_depth)
        return model.cuda() if torch.cuda.is_available() else model

    def __softmax(self, x):
        """Softmax over all elements of ``x`` (max subtracted for numerical stability)."""
        e_x = np.exp(x - np.max(x))
        return e_x / e_x.sum()

    def __setModelTrain(self):
        self.__model = self.__model.train()

    def __setModelEval(self):
        self.__model = self.__model.eval()

    def fetch_model(self):
        """Return the model, switched to eval mode."""
        return self.__model.eval()

    def fetch_trainHist(self):
        """Return [train losses, test losses, train accs, test accs], one entry per epoch."""
        return self.__trainHist

    def __get_acc(self, output, label):
        """Return the fraction of rows whose rounded prediction equals the label.

        ``output`` and ``label`` are expected to have the same shape (as
        ``BCELoss`` requires); a row counts as correct only if all of its
        elements match.
        """
        matches = torch.round(output) == label
        per_row = matches.reshape(matches.shape[0], -1).all(dim=1)
        return per_row.sum().item() / output.shape[0]

    @staticmethod
    def __mean(total, count):
        """Return ``total / count``, or NaN when ``count`` is zero (empty split)."""
        return total / count if count else float('nan')

    def train(self):
        """Train for ``epochs`` epochs, evaluating on the test split after each one.

        Appends the per-epoch average train loss, test loss, train accuracy
        and test accuracy to the history lists. An average over zero batches
        is recorded as NaN. Saves the whole model to
        './resnet_market_predictor.pth' at the end when ``save_model`` was set.
        """
        # Batches must live on the model's device (the model is moved to CUDA
        # when available, but the loaders yield whatever the dataset produces).
        device = next(self.__model.parameters()).device
        for epoch in range(self.__epochs):
            start_time = time.time()
            train_loss_sum, test_loss_sum = 0.0, 0.0
            train_acc_sum, test_acc_sum = 0.0, 0.0
            train_total, test_total = 0, 0

            self.__setModelTrain()
            for im, label in self.__train_set:
                train_total += 1
                im, label = im.to(device), label.to(device)
                pred = self.__model(im)
                train_loss = self.__criterion(pred, label)
                self.__optim.zero_grad()
                train_loss.backward()
                self.__optim.step()
                loss_val = train_loss.item()
                acc = self.__get_acc(pred, label)
                train_loss_sum += loss_val
                train_acc_sum += acc
                print(
                    'Training Batch No.: {:3d}\nTraining Loss: {:.5f} ; Training Acc.: {:.5f}'
                    .format(train_total, loss_val, acc))

            self.__setModelEval()
            # Evaluation needs no autograd graph.
            with torch.no_grad():
                for im, label in self.__test_set:
                    test_total += 1
                    im, label = im.to(device), label.to(device)
                    pred = self.__model(im)
                    test_loss = self.__criterion(pred, label)
                    loss_val = test_loss.item()
                    acc = self.__get_acc(pred, label)
                    test_loss_sum += loss_val
                    test_acc_sum += acc
                    print(
                        'Testing Batch No.: {:3d}\nTesting Loss: {:.5f} ; Testing Acc.: {:.5f}'
                        .format(test_total, loss_val, acc))

            avg_train_loss = self.__mean(train_loss_sum, train_total)
            avg_test_loss = self.__mean(test_loss_sum, test_total)
            avg_train_acc = self.__mean(train_acc_sum, train_total)
            # Bug fix: test accuracy was previously divided by the train batch count.
            avg_test_acc = self.__mean(test_acc_sum, test_total)

            self.__trainHist[0].append(avg_train_loss)
            self.__trainHist[1].append(avg_test_loss)
            self.__trainHist[2].append(avg_train_acc)
            self.__trainHist[3].append(avg_test_acc)
            print(
                'Epoch: {:3d} / {:3d}\nAverage Training Loss: {:.6f} ; Average Validation Loss: {:.6f}\nTrain Accuracy: {:.3f} ; Test Accuracy: {:.3f}\nTime Taken: {:.6f}\n'
                .format(epoch + 1, self.__epochs,
                        avg_train_loss,
                        avg_test_loss,
                        avg_train_acc,
                        avg_test_acc,
                        time.time() - start_time))

        if self.__saveModel:
            torch.save(self.__model, './resnet_market_predictor.pth')
class SRSolver(BaseSolver):
    """Training / inference driver for a super-resolution network.

    Handles model creation, pixel loss, optimizer and LR-scheduler set-up,
    split-batch training with loss-spike skipping, (chopped / self-ensembled)
    inference, checkpointing, visualization and CSV logging.

    Several attributes read here (``Tensor``, ``is_train``, ``use_gpu``,
    ``use_cl``, ``split_batch``, ``skip_threshold``, ``last_epoch_loss``,
    ``use_chop``, ``self_ensemble``, ``scale``, ``checkpoint_dir``,
    ``visual_dir``, ``records_dir``, ``exp_root``, ``save_vis_step``,
    ``best_pred``, ``best_epoch``, ``cur_epoch``) are not assigned in this
    class and presumably come from ``BaseSolver`` — confirm there.
    """

    def __init__(self, opt):
        super(SRSolver, self).__init__(opt)
        self.train_opt = opt['solver']
        # Reusable input/target buffers, resized in place by feed_data().
        self.LR = self.Tensor()
        self.HR = self.Tensor()
        self.SR = None

        # Per-epoch history written out by save_current_log().
        self.records = {'train_loss': [],
                        'val_loss': [],
                        'psnr': [],
                        'ssim': [],
                        'lr': []}

        self.model = create_model(opt)
        self.print_network()

        if self.is_train:
            self.model.train()

            # set cl_loss: one weight per model output (see train_step)
            if self.use_cl:
                self.cl_weights = self.opt['solver']['cl_weights']
                assert self.cl_weights, "[Error] 'cl_weights' is not be declared when 'use_cl' is true"

            # set loss
            loss_type = self.train_opt['loss_type']
            if loss_type == 'l1':
                self.criterion_pix = nn.L1Loss()
            elif loss_type == 'l2':
                self.criterion_pix = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not implemented!'%loss_type)

            if self.use_gpu:
                self.criterion_pix = self.criterion_pix.cuda()

            # set optimizer; a falsy weight_decay (None / 0) means no decay
            weight_decay = self.train_opt['weight_decay'] or 0
            optim_type = self.train_opt['type'].upper()
            if optim_type == "ADAM":
                self.optimizer = optim.Adam(self.model.parameters(),
                                            lr=self.train_opt['learning_rate'], weight_decay=weight_decay)
            elif optim_type == 'ADABOUND':
                self.optimizer = AdaBound(self.model.parameters(),
                                          lr=self.train_opt['learning_rate'], weight_decay=weight_decay)
            elif optim_type == 'SGD':
                self.optimizer = optim.SGD(self.model.parameters(),
                                           lr=self.train_opt['learning_rate'], momentum=0.90, weight_decay=weight_decay)
            else:
                raise NotImplementedError('Optimizer type [%s] is not implemented!' % optim_type)

            # set lr_scheduler
            lr_scheme = self.train_opt['lr_scheme'].lower()
            if lr_scheme == 'multisteplr':
                self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                                self.train_opt['lr_steps'],
                                                                self.train_opt['lr_gamma'])
            elif lr_scheme == 'cos':
                self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer,
                                                                      T_max=self.opt['solver']['num_epochs'],
                                                                      eta_min=self.train_opt['lr_min'])
            else:
                raise NotImplementedError("Only 'multisteplr' (MultiStepLR) and 'cos' (CosineAnnealingLR) schemes are supported!")

        self.load()

        print('===> Solver Initialized : [%s] || Use CL : [%s] || Use GPU : [%s]'%(self.__class__.__name__,
                                                                                       self.use_cl, self.use_gpu))
        if self.is_train:
            print("optimizer: ", self.optimizer)
            if self.train_opt['lr_scheme'].lower() == 'multisteplr':
                print("lr_scheduler milestones: %s   gamma: %f"%(self.scheduler.milestones, self.scheduler.gamma))

    def _net_init(self, init_type='kaiming'):
        """Initialize the network weights via ``init_weights`` using ``init_type``."""
        print('==> Initializing the network using [%s]'%init_type)
        init_weights(self.model, init_type)

    def feed_data(self, batch, need_HR=True):
        """Copy ``batch['LR']`` (and ``batch['HR']`` when ``need_HR``) into the LR/HR buffers."""
        lr_input = batch['LR']
        self.LR.resize_(lr_input.size()).copy_(lr_input)

        if need_HR:
            target = batch['HR']
            self.HR.resize_(target.size()).copy_(target)

    def train_step(self):
        """Run one optimization step on the current LR/HR batch.

        The batch is processed in ``split_batch`` sub-batches whose scaled
        losses are back-propagated separately, accumulating gradients. The
        optimizer step is skipped when the loss exceeds
        ``skip_threshold * last_epoch_loss``.

        Returns:
            float: the accumulated loss of the whole batch.
        """
        self.model.train()
        self.optimizer.zero_grad()

        loss_batch = 0.0
        # NOTE(review): int() truncation drops trailing samples when the batch
        # size is not a multiple of split_batch.
        sub_batch_size = int(self.LR.size(0) / self.split_batch)
        for i in range(self.split_batch):
            loss_sbatch = 0.0
            split_LR = self.LR.narrow(0, i*sub_batch_size, sub_batch_size)
            split_HR = self.HR.narrow(0, i*sub_batch_size, sub_batch_size)
            if self.use_cl:
                # The model returns several outputs; combine their pixel losses with cl_weights.
                outputs = self.model(split_LR)
                loss_steps = [self.criterion_pix(sr, split_HR) for sr in outputs]
                for step in range(len(loss_steps)):
                    loss_sbatch += self.cl_weights[step] * loss_steps[step]
            else:
                output = self.model(split_LR)
                loss_sbatch = self.criterion_pix(output, split_HR)

            # Scale so the accumulated gradient averages over the sub-batches.
            loss_sbatch /= self.split_batch
            loss_sbatch.backward()

            loss_batch += (loss_sbatch.item())

        # for stable training: skip the update on a loss spike.
        # Note: last_epoch_loss is updated after every accepted batch.
        if loss_batch < self.skip_threshold * self.last_epoch_loss:
            self.optimizer.step()
            self.last_epoch_loss = loss_batch
        else:
            print('[Warning] Skip this batch! (Loss: {})'.format(loss_batch))

        self.model.eval()
        return loss_batch

    def test(self):
        """Run inference on ``self.LR`` and store the (last) output in ``self.SR``.

        Uses ``_overlap_crop_forward`` when ``use_chop`` is set and the x8
        self-ensemble when ``self_ensemble`` is set outside training. The
        model is switched back to train mode afterwards.

        Returns:
            float | None: pixel loss against ``self.HR`` when training, else None.
        """
        self.model.eval()
        with torch.no_grad():  # forward pass only; no autograd graph is built
            forward_func = self._overlap_crop_forward if self.use_chop else self.model.forward
            if self.self_ensemble and not self.is_train:
                SR = self._forward_x8(self.LR, forward_func)
            else:
                SR = forward_func(self.LR)

            # Models returning a list of outputs: keep the final one.
            if isinstance(SR, list):
                self.SR = SR[-1]
            else:
                self.SR = SR

        self.model.train()
        if self.is_train:
            loss_pix = self.criterion_pix(self.SR, self.HR)
            return loss_pix.item()

    def _forward_x8(self, x, forward_function):
        """
        self ensemble: run ``forward_function`` on the 8 flip/transpose
        variants of ``x``, undo each transform, and average the results.

        The final mean over dim 0 with ``keepdim=True`` collapses the batch,
        so this assumes ``x`` holds a single image — TODO confirm with callers.
        """
        def _transform(v, op):
            v = v.float()

            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()
            else:
                raise ValueError('Unknown transform op [%s]' % op)

            ret = self.Tensor(tfnp)

            return ret

        # Index bits of lr_list: bit0 = 'v', bit1 = 'h', bit2 = 't'.
        lr_list = [x]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])

        sr_list = []
        for aug in lr_list:
            sr = forward_function(aug)
            if isinstance(sr, list):
                sr_list.append(sr[-1])
            else:
                sr_list.append(sr)

        # Undo the transforms in reverse order of application.
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        output = output_cat.mean(dim=0, keepdim=True)

        return output

    def _overlap_crop_forward(self, x, shave=10, min_size=100000, bic=None):
        """
        chop for less memory consumption during test

        Splits ``x`` into four quadrants overlapping by ``shave`` pixels,
        super-resolves them (recursing while a quadrant has at least
        ``min_size`` pixels) and stitches the upscaled quadrants together.
        ``bic`` crops are computed but not used in the forward pass.
        """
        n_GPUs = 2
        scale = self.scale
        b, c, h, w = x.size()
        h_half, w_half = h // 2, w // 2
        h_size, w_size = h_half + shave, w_half + shave
        lr_list = [
            x[:, :, 0:h_size, 0:w_size],
            x[:, :, 0:h_size, (w - w_size):w],
            x[:, :, (h - h_size):h, 0:w_size],
            x[:, :, (h - h_size):h, (w - w_size):w]]

        if bic is not None:
            bic_h_size = h_size*scale
            bic_w_size = w_size*scale
            bic_h = h*scale
            bic_w = w*scale

            bic_list = [
                bic[:, :, 0:bic_h_size, 0:bic_w_size],
                bic[:, :, 0:bic_h_size, (bic_w - bic_w_size):bic_w],
                bic[:, :, (bic_h - bic_h_size):bic_h, 0:bic_w_size],
                bic[:, :, (bic_h - bic_h_size):bic_h, (bic_w - bic_w_size):bic_w]]

        if w_size * h_size < min_size:
            sr_list = []
            # Run the quadrants n_GPUs at a time, concatenated along the batch dim.
            for i in range(0, 4, n_GPUs):
                lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
                if bic is not None:
                    bic_batch = torch.cat(bic_list[i:(i + n_GPUs)], dim=0)

                sr_batch_temp = self.model(lr_batch)

                if isinstance(sr_batch_temp, list):
                    sr_batch = sr_batch_temp[-1]
                else:
                    sr_batch = sr_batch_temp

                sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
        else:
            sr_list = [
                self._overlap_crop_forward(patch, shave=shave, min_size=min_size) \
                for patch in lr_list
                ]

        h, w = scale * h, scale * w
        h_half, w_half = scale * h_half, scale * w_half
        h_size, w_size = scale * h_size, scale * w_size
        shave *= scale

        # Output keeps the input's channel count c.
        output = x.new(b, c, h, w)
        output[:, :, 0:h_half, 0:w_half] \
            = sr_list[0][:, :, 0:h_half, 0:w_half]
        output[:, :, 0:h_half, w_half:w] \
            = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
        output[:, :, h_half:h, 0:w_half] \
            = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
        output[:, :, h_half:h, w_half:w] \
            = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

        return output

    def save_checkpoint(self, epoch, is_best):
        """
        save checkpoint to experimental dir

        Always writes 'last_ckp.pth'; also 'best_ckp.pth' when ``is_best`` and
        'epoch_<epoch>_ckp.pth' every ``save_ckp_step`` epochs.
        """
        filename = os.path.join(self.checkpoint_dir, 'last_ckp.pth')
        print('===> Saving last checkpoint to [%s] ...]'%filename)
        ckp = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
            'best_epoch': self.best_epoch,
            'records': self.records
        }
        torch.save(ckp, filename)
        if is_best:
            best_filename = os.path.join(self.checkpoint_dir, 'best_ckp.pth')
            print('===> Saving best checkpoint to [%s] ...]' % best_filename)
            torch.save(ckp, best_filename)

        if epoch % self.train_opt['save_ckp_step'] == 0:
            # Built directly so the name gets a single '.pth' extension.
            epoch_filename = os.path.join(self.checkpoint_dir, 'epoch_%d_ckp.pth' % epoch)
            print('===> Saving checkpoint [%d] to [%s] ...]' % (epoch, epoch_filename))

            torch.save(ckp, epoch_filename)

    def load(self):
        """
        load or initialize network

        Loads weights from ``pretrained_path`` when fine-tuning (``is_train``
        and ``pretrain``) or when testing; otherwise initializes the network.
        """
        if (self.is_train and self.opt['solver']['pretrain']) or not self.is_train:
            model_path = self.opt['solver']['pretrained_path']
            if model_path is None: raise ValueError("[Error] The 'pretrained_path' does not declarate in *.json")

            print('===> Loading model from [%s]...' % model_path)
            if self.is_train:
                checkpoint = torch.load(model_path)
                self.model.load_state_dict(checkpoint['state_dict'])

                # if self.opt['solver']['pretrain'] == 'resume':
                #     self.cur_epoch = checkpoint['epoch'] + 1
                #     self.optimizer.load_state_dict(checkpoint['optimizer'])
                #     self.best_pred = checkpoint['best_pred']
                #     self.best_epoch = checkpoint['best_epoch']
                #     self.records = checkpoint['records']

            else:
                checkpoint = torch.load(model_path)
                if 'state_dict' in checkpoint.keys(): checkpoint = checkpoint['state_dict']
                if isinstance(self.model, nn.DataParallel):
                    self.model.load_state_dict(checkpoint)
                else:
                    # A bare model has no `.module`; strip the 'module.' prefix that
                    # DataParallel adds to state_dict keys so such checkpoints load.
                    prefix = 'module.'
                    checkpoint = OrderedDict(
                        (k[len(prefix):] if k.startswith(prefix) else k, v)
                        for k, v in checkpoint.items())
                    self.model.load_state_dict(checkpoint)
        else:
            print('===> Initialize model')
            self._net_init()

    def get_current_visual(self, need_np=True, need_HR=True):
        """
        return LR SR (HR) images (first item of the batch), as numpy arrays when ``need_np``
        """
        out_dict = OrderedDict()
        out_dict['LR'] = self.LR.data[0].float().cpu()
        out_dict['SR'] = self.SR.data[0].float().cpu()
        if need_np:  out_dict['LR'], out_dict['SR'] = util.Tensor2np([out_dict['LR'], out_dict['SR']],
                                                                        self.opt['rgb_range'])
        if need_HR:
            out_dict['HR'] = self.HR.data[0].float().cpu()
            if need_np: out_dict['HR'] = util.Tensor2np([out_dict['HR']],
                                                           self.opt['rgb_range'])[0]
        return out_dict

    def save_current_visual(self, epoch, iter):
        """
        save visual results for comparison (HR and SR side by side) every ``save_vis_step`` epochs
        """
        if epoch % self.save_vis_step == 0:
            visuals_list = []
            visuals = self.get_current_visual(need_np=False)
            visuals_list.extend([util.quantize(visuals['HR'].squeeze(0), self.opt['rgb_range']),
                                 util.quantize(visuals['SR'].squeeze(0), self.opt['rgb_range'])])
            visual_images = torch.stack(visuals_list)
            visual_images = thutil.make_grid(visual_images, nrow=2, padding=5)
            visual_images = visual_images.byte().permute(1, 2, 0).numpy()
            misc.imsave(os.path.join(self.visual_dir, 'epoch_%d_img_%d.png' % (epoch, iter + 1)),
                        visual_images)

    def get_current_learning_rate(self):
        """Return the learning rate of the optimizer's first parameter group."""
        return self.optimizer.param_groups[0]['lr']

    def update_learning_rate(self, epoch):
        """Advance the LR scheduler, passing ``epoch`` explicitly."""
        self.scheduler.step(epoch)

    def get_current_log(self):
        """Return epoch, best result and records as an OrderedDict."""
        log = OrderedDict()
        log['epoch'] = self.cur_epoch
        log['best_pred'] = self.best_pred
        log['best_epoch'] = self.best_epoch
        log['records'] = self.records
        return log

    def set_current_log(self, log):
        """Restore the state produced by ``get_current_log``."""
        self.cur_epoch = log['epoch']
        self.best_pred = log['best_pred']
        self.best_epoch = log['best_epoch']
        self.records = log['records']

    def save_current_log(self):
        """Write the per-epoch records to 'train_records.csv' (epochs numbered from 1)."""
        data_frame = pd.DataFrame(
            data={'train_loss': self.records['train_loss'],
                  'val_loss': self.records['val_loss'],
                  'psnr': self.records['psnr'],
                  'ssim': self.records['ssim'],
                  'lr': self.records['lr'],
                  },
            index=range(1, self.cur_epoch + 1)
        )
        data_frame.to_csv(os.path.join(self.records_dir, 'train_records.csv'),
                          index_label='epoch')

    def print_network(self):
        """
        print network summary including module and number of parameters
        (also written to 'network_summary.txt' when training)
        """
        s, n = self.get_network_description(self.model)
        if isinstance(self.model, nn.DataParallel):
            net_struc_str = '{} - {}'.format(self.model.__class__.__name__,
                                                 self.model.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.model.__class__.__name__)

        print("==================================================")
        print("===> Network Summary\n")
        net_lines = []
        line = s + '\n'
        print(line)
        net_lines.append(line)
        line = 'Network structure: [{}], with parameters: [{:,d}]'.format(net_struc_str, n)
        print(line)
        net_lines.append(line)

        if self.is_train:
            with open(os.path.join(self.exp_root, 'network_summary.txt'), 'w') as f:
                f.writelines(net_lines)

        print("==================================================")
Exemple #8
0
class TranslatorModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new model-specific options and rewrite default values for existing options.

        In the training phase, the 'mel', 'normalize' and 'stft' steps are
        removed from the parsed ``--preprocess`` value and ',mulaw' is
        appended. In the test phase, the preprocess default is just 'mulaw'.

        Parameters:
            parser -- the option parser
            is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        if is_train:
            opt, _ = parser.parse_known_args()
            preprocess = opt.preprocess
            for step in ('mel', 'normalize', 'stft'):
                preprocess = preprocess.replace(step, '')
            preprocess += ',mulaw'
        else:
            preprocess = 'mulaw'
        parser.set_defaults(preprocess=preprocess)
        parser.add_argument('--wavenet_layers',
                            type=int,
                            default=30,
                            help='wavenet layers')
        parser.add_argument('--wavenet_blocks',
                            type=int,
                            default=10,
                            help='wavenet blocks')
        parser.add_argument('--bottleneck',
                            type=int,
                            default=64,
                            help='bottleneck')
        parser.add_argument('--dc_width',
                            type=int,
                            default=128,
                            help='dc width')
        parser.add_argument('--width', type=int, default=128, help='width')
        parser.add_argument('--pool_length',
                            type=int,
                            default=1024,
                            help='pool length')
        parser.add_argument('--dc_lambda',
                            type=float,
                            default=0.01,
                            help='dc lambda')
        parser.add_argument('--dc_no_bias',
                            action='store_true',
                            help='disable bias in the domain confusion network')
        return parser

    def __init__(self, opt):
        """Build the encoder, domain classifier and two WaveNet decoders.

        In training mode, this also creates the domain targets, the losses and
        two AdaBound optimizers. Otherwise it loads saved networks, puts
        netC/netD_A/netD_B in eval mode and builds NVWaveNet inference
        engines from the decoders' exported weights.
        """
        # Base initialization; self.devices / self.isTrain used below are
        # presumably set there — confirm in BaseModel.
        BaseModel.__init__(self, opt)
        self.loss_names = ['C_A_right', 'C_B_right', 'C_A_wrong', 'C_B_wrong', 'D_A', 'D_B']
        self.output_names = [] if opt.isTrain else ['real_A', 'real_B', 'fake_B', 'fake_A']
        self.params_names = ['params_A', 'params_B'] * 2
        self.model_names = ['E', 'C', 'D_A', 'D_B']

        # Encoder and domain classifier live on the first device.
        self.netE = TemporalEncoder(width=opt.width,
                                    bottleneck_width=opt.bottleneck,
                                    pool_length=opt.pool_length).to(self.devices[0])
        self.netC = DomainConfusion(3, 2, opt.bottleneck, opt.dc_width,
                                    opt.dc_lambda,
                                    not opt.dc_no_bias).to(self.devices[0])

        # Both decoders share one configuration and live on the last device.
        wavenet_args = (opt.mu + 1, opt.wavenet_layers, opt.wavenet_blocks,
                        opt.width, 256, 256, opt.bottleneck, 1, 1)
        self.netD_A = WaveNet(*wavenet_args).to(self.devices[-1])
        self.netD_B = WaveNet(*wavenet_args).to(self.devices[-1])
        # LogSoftmax over dim 1 (the class axis of the decoder output).
        self.softmax = nn.LogSoftmax(dim=1)

        if not self.isTrain:
            self.preprocesses = ['mulaw']
            # TODO change structure of test.py and setup() instead
            load_suffix = str(opt.load_iter) if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix)
            # NOTE(review): netE is not switched to eval() here.
            for net in (self.netC, self.netD_A, self.netD_B):
                net.eval()

            self.infer_A = NVWaveNet(**self.netD_A.export_weights())
            self.infer_B = NVWaveNet(**self.netD_B.export_weights())
            return

        # Fixed domain labels: 0 for A, 1 for B, one per batch element.
        batch = opt.batch_size
        self.A_target = torch.LongTensor([0] * batch).to(self.devices[0])
        self.B_target = torch.LongTensor([1] * batch).to(self.devices[0])
        self.criterionDC = nn.CrossEntropyLoss(reduction='mean')
        self.criterionDecode = nn.NLLLoss(reduction='mean')
        # optimizer_C updates only the classifier; optimizer_D updates the
        # encoder together with both decoders.
        self.optimizer_C = AdaBound(self.netC.parameters(), lr=opt.lr, final_lr=0.1)
        decoder_params = itertools.chain(self.netE.parameters(),
                                         self.netD_A.parameters(),
                                         self.netD_B.parameters())
        self.optimizer_D = AdaBound(decoder_params, lr=opt.lr, final_lr=0.1)
        self.optimizers = [self.optimizer_C, self.optimizer_D]

    def set_input(self, input):
        """Unpack the (A, params_A), (B, params_B) pair and move tensors to the first device.

        ``aug_*`` take the first element of each pair and ``real_*`` the
        ``params_*['original']`` tensor. The params dicts are stored after
        ``decollate_params``.
        """
        (A, params_A), (B, params_B) = input[0], input[1]
        device = self.devices[0]
        self.aug_A = A.to(device)
        self.aug_B = B.to(device)
        self.real_A = params_A['original'].to(device)
        self.real_B = params_B['original'].to(device)

        self.params_A = self.decollate_params(params_A)
        self.params_B = self.decollate_params(params_B)

    def get_indices(self, y):
        """Map values in [-1, 1] to integer bins 0..mu (scaled, then cast with ``.long()``)."""
        return ((y + 1.) * .5 * self.opt.mu).long()

    # def to_onehot(self, y, device):
    #     y = self.get_indices(y).view(-1, 1)
    #     y = torch.zeros(y.size()[0], self.opt.mu + 1).to(device).scatter_(1, y, 1)
    #     return y.transpose(0, 1).unsqueeze(0)

    def inv_indices(self, y):
        """Map integer bins 0..mu back to floats in [-1, 1] (inverse of ``get_indices`` up to quantization)."""
        as_float = y.float()
        return as_float / self.opt.mu * 2. - 1.

    @staticmethod
    def sample(logits):
        """Draw one category index per position from unnormalized ``logits``.

        ``Categorical`` treats the last dimension as the categories, so dims 1
        and 2 are swapped first. This presumably assumes logits are
        (batch, classes, time) — TODO confirm.
        """
        per_step_logits = logits.transpose(1, 2)
        distribution = torch.distributions.categorical.Categorical(logits=per_step_logits)
        return distribution.sample()

    def train(self):
        """Run one adversarial training step on the current A/B batch.

        Phase 1 updates only the domain classifier netC (via optimizer_C) to
        predict the domain of encodings of ``real_A`` / ``real_B``.
        Phase 2 updates the encoder and both decoders (via optimizer_D).
        Each decoder reconstructs its domain's quantized ``real_*`` signal from
        the encoding of the ``aug_*`` input. The classifier loss is subtracted,
        so this step pushes the encoder to make the domains
        indistinguishable.

        Sets ``loss_C_A_right``, ``loss_C_B_right``, ``loss_C_A_wrong``,
        ``loss_C_B_wrong``, ``loss_D_A`` and ``loss_D_B``.
        """
        # --- Phase 1: domain classifier ---
        # backward() here also leaves gradients on netE; they are cleared by
        # optimizer_D.zero_grad() below (optimizer_D covers netE's parameters).
        self.optimizer_C.zero_grad()
        encoded_A = self.netE(
            self.real_A.unsqueeze(1))  # Input range: (-1, 1) Output: R^64
        pred_C_A = self.netC(
            encoded_A)  # (encoded_A + 1.) * self.vector_length / 2)
        # NOTE(review): A_target has opt.batch_size entries; assumes every
        # batch is exactly that size — confirm for a final partial batch.
        self.loss_C_A_right = self.criterionDC(pred_C_A, self.A_target)
        loss = self.opt.dc_lambda * self.loss_C_A_right
        loss.backward()

        encoded_B = self.netE(self.real_B.unsqueeze(1))
        pred_C_B = self.netC(
            encoded_B)  # (encoded_B + 1.) * self.vector_length / 2)
        self.loss_C_B_right = self.criterionDC(pred_C_B, self.B_target)
        loss = self.opt.dc_lambda * self.loss_C_B_right
        loss.backward()
        self.optimizer_C.step()

        # --- Phase 2: encoder + decoders ---
        # netC also receives gradients below but is not stepped by optimizer_D.
        self.optimizer_D.zero_grad()
        encoded_A = self.netE(
            self.aug_A.unsqueeze(1))  # Input range: (-1, 1) Output: R^64
        pred_C_A = self.netC(encoded_A)
        self.loss_C_A_wrong = self.criterionDC(pred_C_A, self.A_target)
        # Upsample the pooled encoding to audio length as decoder conditioning,
        # and move it to the decoder's device.
        encoded_A = nn.functional.interpolate(encoded_A,
                                              size=self.opt.audio_length).to(
                                                  self.devices[-1])
        real_A = self.get_indices(self.real_A).to(self.devices[-1])
        pred_D_A = self.netD_A((encoded_A, real_A))
        rec_A = self.softmax(pred_D_A)
        self.loss_D_A = self.criterionDecode(rec_A, real_A)
        # Minimizing this maximizes the classifier loss w.r.t. the encoder.
        loss = self.loss_D_A - self.opt.dc_lambda * self.loss_C_A_wrong
        loss.backward()

        encoded_B = self.netE(self.aug_B.unsqueeze(1))
        pred_C_B = self.netC(encoded_B)
        self.loss_C_B_wrong = self.criterionDC(pred_C_B, self.B_target)
        encoded_B = nn.functional.interpolate(encoded_B,
                                              size=self.opt.audio_length).to(
                                                  self.devices[-1])
        real_B = self.get_indices(self.real_B).to(self.devices[-1])
        pred_D_B = self.netD_B((encoded_B, real_B))
        rec_B = self.softmax(pred_D_B)
        self.loss_D_B = self.criterionDecode(rec_B, real_B)
        loss = self.loss_D_B - self.opt.dc_lambda * self.loss_C_B_wrong
        loss.backward()
        self.optimizer_D.step()

    def test(self):
        with torch.no_grad():
            encoded_A = self.netE(self.aug_A.unsqueeze(1))
            encoded_B = self.netE(self.aug_B.unsqueeze(1))
            # pred_C_A = self.softmax(self.netC(encoded_A))
            # pred_C_B = self.softmax(self.netC(encoded_B))
            encoded_A = nn.functional.interpolate(
                encoded_A, size=self.opt.audio_length).to(self.devices[-1])
            encoded_B = nn.functional.interpolate(
                encoded_B, size=self.opt.audio_length).to(self.devices[-1])
            self.fake_A = self.infer_A.infer(
                self.netD_A.get_cond_input(encoded_B), Impl.AUTO)
            self.fake_B = self.infer_B.infer(
                self.netD_B.get_cond_input(encoded_A), Impl.AUTO)
            self.fake_A = self.inv_indices(self.fake_A)
            self.fake_B = self.inv_indices(self.fake_B)