def train(self, loaders, num_epochs=15):
    """Full training driver: trains on the train split, evaluates on the
    validation split, and saves visualizations/checkpoints every epoch.

    `loaders` is indexed as (train_loader, valid_loader); extra entries,
    if any, are ignored.
    """
    self.make_output_dir()
    scheduler = StepLR(self.optimizer, step_size=12, gamma=0.1)
    # Fixed seed for reproducible runs.
    torch.manual_seed(0)
    train_loader, valid_loader = loaders[0], loaders[1]
    print("Starting Training")
    for epoch in range(num_epochs):
        tr_metrics, train_loss = self.train_epoch(
            train_loader, optimizer=self.optimizer, scheduler=scheduler)
        val_metrics, valid_loss, pred, GT = self.valid_epoch(
            valid_loader, epoch=epoch)
        self.visualize_and_save(pred, GT, epoch)
        self.save_state_dicts(epoch)
        self.write_results(epoch)
        scheduler.step()
        print("Current Learning Rate is ", scheduler.get_last_lr())
        # Book-keeping for the final plots.
        for history, value in ((self.train_losses, train_loss),
                               (self.valid_losses, valid_loss),
                               (self.tr_dice, tr_metrics[0]),
                               (self.tr_jacc, tr_metrics[1]),
                               (self.val_dice, val_metrics[0]),
                               (self.val_jacc, val_metrics[1])):
            history.append(value)
        print(
            'Epoch {}: \t train_dice: {:.2f} \t train_jacc: {:.2f} \t val_dice: {:.2f} \t val_jacc: {:.2f} \t train_loss: {:.2f} \t'
            ' valid_loss: {:.2f}'.format(epoch, tr_metrics[0],
                                         tr_metrics[1], val_metrics[0],
                                         val_metrics[1], train_loss,
                                         valid_loss))
    self.plot_and_save()
# 示例#2 ("Example #2") — stray separator from the site this snippet was
# scraped from; commented out because the bare names would raise at import.
# 0
    def exponentialWarmUp(self,
                          w,
                          batch_size,
                          num_iters=1000,
                          lr_lims=None,
                          range_test=False,
                          explosion_ratio=None):
        """
        Warm the network up with an exponentially increasing learning rate,
        starting from a small value. If `range_test` is `True`, the network
        and optimizer are restored afterwards and the LR-vs-loss curve is
        displayed instead (an LR range test).

        `w` is a weight vector combining the per-term losses returned by
        `self.pde.computeLoss` (total loss = `w @ L`).

        With `explosion_ratio`, you can reduce the method's execution time
        by setting the argument's value close to one. It limits the maximum
        possible ratio between the total loss at some iteration and the
        total loss at the first step; the loop stops early once exceeded.
        """
        # One row of per-term losses per iteration (same device/dtype as w).
        loss_logs = w.new_empty([num_iters, w.nelement()])
        # Snapshot optimizer (and, for a range test, network) state so the
        # sweep can be undone afterwards.
        opt_dict = copy.deepcopy(self.optimizer.state_dict())
        if range_test:
            net_weights = copy.deepcopy(self.net.state_dict())

        if lr_lims is not None:
            lr_0, lr_n = lr_lims
            assert lr_n > lr_0 and lr_0 > 0, 'Inappropriate range'
            # gamma^num_iters == lr_n / lr_0, i.e. exponential sweep.
            gamma = (lr_n / lr_0)**(1. / num_iters)
        else:
            # Default range: 5 decades below the scheduler's base LR.
            lr_n = self.scheduler.base_lrs[0]
            assert lr_n != 0, 'Zero optimizer lr'
            lr_0 = lr_n * 1e-5
            gamma = (1e5)**(1. / num_iters)

        for par in self.optimizer.param_groups:
            par['lr'] = lr_0
        # Multiply LR by gamma every step (StepLR with step_size=1).
        scheduler = StepLR(self.optimizer, 1, gamma)
        # LR actually used at each iteration (x-axis of the range-test plot).
        xdata = np.empty(num_iters, 'f4')

        desc = 'Range test' if range_test else 'Warm-up'
        for i in trange(num_iters, desc=desc):
            self.optimizer.zero_grad()
            xdata[i] = scheduler.get_last_lr()[0]

            batch = self.pde.sampleBatch(batch_size)
            L = self.pde.computeLoss(batch, self.net)
            loss_logs[i] = L.data
            # Total loss is the weighted combination of the loss terms.
            (w @ L).backward()

            self.optimizer.step()
            scheduler.step()
            # Early stop once the total loss has grown explosion_ratio times
            # relative to the very first iteration.
            # NOTE(review): on early stop, rows of loss_logs past `i` stay
            # uninitialized (new_empty) — downstream plotting presumably
            # truncates at `i`; verify in displayRTresults.
            if (explosion_ratio is not None
                    and explosion_ratio * (w @ loss_logs[0]).item() <
                (w @ L).item()):
                print('Early stop due to the loss explosion')
                break
        # Undo the sweep's effect on the optimizer (momentum buffers, LR).
        self.optimizer.load_state_dict(opt_dict)
        if range_test:
            self.net.load_state_dict(net_weights)
            # NOTE(review): `i` is the last loop index; unbound if
            # num_iters == 0 — assumed num_iters >= 1.
            args = ('expo', explosion_ratio, w, i)
            displayRTresults(xdata, loss_logs, *args)
# 示例#3 ("Example #3") — stray separator from the site this snippet was
# scraped from; commented out because the bare names would raise at import.
# 0
class LinearStepLR:
    """Thin wrapper around ``StepLR`` that spreads an exponential decay from
    ``init_lr`` down to (approximately) ``eta_min`` evenly over ``epoch``
    epochs.

    ``n`` is the number of ``decay_rate`` multiplications needed to reach
    ``eta_min`` from ``init_lr``; the step size is chosen so those ``n``
    decays are distributed across the ``epoch`` epochs.
    """

    def __init__(self, optimizer, init_lr, epoch, eta_min, decay_rate):
        # Number of decay steps: smallest n with init_lr * decay_rate**n <= eta_min
        # (clamped to >= 1 for degenerate ranges, e.g. eta_min >= init_lr).
        n = max(1, int((log(eta_min) - log(init_lr)) / log(decay_rate)) + 1)
        # BUG FIX: the original `int(epoch / n)` yields 0 whenever
        # epoch < n, and StepLR with step_size=0 raises ZeroDivisionError
        # on the first .step(). Clamp to at least one epoch per decay.
        step_size = max(1, epoch // n)
        self.scheduler = StepLR(optimizer=optimizer,
                                gamma=decay_rate,
                                step_size=step_size)

    def get_last_lr(self):
        """Return the learning rate(s) last computed by the scheduler."""
        return self.scheduler.get_last_lr()

    def step(self):
        """Advance the underlying StepLR by one epoch."""
        self.scheduler.step()
def train(i, model, trainloader, testloader, optimizer):
    """Train `model` for `i` epochs, testing after every epoch.

    Delegates the actual work to `train_step` (one training epoch) and
    `test` (one evaluation pass); the LR is decayed by 0.9 every 5 epochs.

    Returns:
        (train_acc, test_acc): accuracy histories as produced by
        `train_step` / `test`.
    """
    train_acc = []
    test_acc = []
    scheduler = StepLR(optimizer, step_size=5, gamma=0.9)
    for epoch in range(i):
        s = time.time()
        train_acc = train_step(epoch, train_acc, model, trainloader, optimizer)
        test_acc, _ = test(test_acc, model, testloader)
        scheduler.step()
        e = time.time()
        print('This epoch took', e - s, 'seconds to train')
        print('Current learning rate: ', scheduler.get_last_lr()[0])
    # BUG FIX: the reported value is max(test_acc), so the message must say
    # "test", not "training". Also guard against i == 0, where max() of an
    # empty list would raise ValueError.
    if test_acc:
        print('Best test accuracy overall: ', max(test_acc))
    return train_acc, test_acc
def Train(args):
    """Train a transformer-based dialogue model described by the `args` dict.

    Sets up logging, datasets/loaders, the model and optimizer, then runs
    the epoch loop. Every `check_steps` optimizer steps the validation set
    is evaluated and the best checkpoint saved; every `save_steps` steps a
    periodic snapshot of the weights is written.

    `args` keys used: trainset_path, testset_path, resume, checkpoint_path,
    history_path, log_path, vocab_path, model_save_name, model_resume_name,
    batch_size, end_epoch, lr, loss_check, check_steps, save_steps, GPU_ids,
    embed_path, embed_dim, nheads_transformer.
    """
    # ---- unpack configuration -------------------------------------------
    train_path = args['trainset_path']
    dev_path = args['testset_path']
    resume = args['resume']
    checkpoint_path = args['checkpoint_path']
    history_path = args['history_path']
    log_path = args['log_path']
    vocab_path = args['vocab_path']
    model_name = args['model_save_name']
    model_resume_name = args['model_resume_name']
    batch_size = args['batch_size']
    end_epoch = args['end_epoch']
    lr = args['lr']
    loss_check_freq = args['loss_check']
    check_steps = args['check_steps']
    save_steps = args['save_steps']
    os.environ['CUDA_VISIBLE_DEVICES'] = args['GPU_ids']
    embed_path = args['embed_path']
    embed_dim = args['embed_dim']
    nheads = args['nheads_transformer']
    #########
    # ---- output directories and logging ---------------------------------
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    if not os.path.exists(history_path):
        os.makedirs(history_path)
    log_save_name = 'log_' + model_name + '.log'
    reset_log(log_path + log_save_name)
    logger = logging.getLogger(__name__)
    # Record the full configuration in the log.
    for k, v in args.items():
        logger.info(k + ':' + str(v))
    checkpoint_name = os.path.join(checkpoint_path,
                                   model_name + '_best_ckpt.pth')
    model_ckpt_name = os.path.join(checkpoint_path, model_name + '_best.pkl')

    # Default the resume weights to the "best" checkpoint path.
    if not model_resume_name:
        model_resume_name = model_ckpt_name
    localtime = time.asctime(time.localtime(time.time()))
    logger.info('#####start time:%s' % (localtime))
    time_stamp = int(time.time())
    logger.info('time stamp:%d' % (time_stamp))
    logger.info('######Model: %s' % (model_name))
    logger.info('trainset path :%s' % (train_path))
    logger.info('valset path: %s' % (dev_path))
    logger.info('batch_size:%d' % (batch_size))
    logger.info('learning rate:%f' % (lr))
    logger.info('end epoch:%d' % (end_epoch))

    # ---- data -----------------------------------------------------------
    tokenizer = basic_tokenizer(vocab_path)
    trainset = dialogue_dataset(train_path, tokenizer)
    devset = dialogue_dataset(dev_path, tokenizer)

    # (The messages below are user-facing Chinese strings: number of
    # training / validation samples.)
    print("训练集样本数:%d" % (trainset.__len__()))
    logger.info("训练集样本数:%d" % (trainset.__len__()))
    print("验证集样本数:%d" % (devset.__len__()))
    logger.info("验证集样本数:%d" % (devset.__len__()))

    train_loader = DataLoader(trainset,
                              batch_size=batch_size,
                              num_workers=4,
                              shuffle=True,
                              collate_fn=collate_func,
                              drop_last=True)
    dev_loader = DataLoader(devset,
                            batch_size=batch_size,
                            num_workers=4,
                            shuffle=False,
                            collate_fn=collate_func)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---- model / resume -------------------------------------------------
    model = transformer_base(tokenizer.vocab_size, embed_dim, nheads,
                             embed_path)
    model.to(device)

    if resume != 0:
        logger.info('Resuming from checkpoint...')
        model.load_state_dict(torch.load(model_resume_name))
        checkpoint = torch.load(checkpoint_name)
        best_loss = checkpoint['loss']
        start_epoch = checkpoint['epoch']
        history = checkpoint['history']
    else:
        best_loss = math.inf
        start_epoch = -1
        history = {'train_loss': [], 'val_loss': []}

    # ---- loss / optimizer / scheduler -----------------------------------
    criterion = seq_generation_loss(device=device).to(device)
    # Only optimize parameters that require gradients (frozen embeddings
    # are excluded).
    optim = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                    model.parameters()),
                             lr=lr)  #weight_decay=1e-5
    # NOTE(review): this scheduler is stepped once per validation pass
    # (every `check_steps` optimizer steps), not once per epoch — confirm
    # that is intentional.
    scheduler = StepLR(optim, step_size=5, gamma=0.9)

    steps_cnt = 0
    for epoch in range(start_epoch + 1, end_epoch):
        print('-------------epoch:%d--------------' % (epoch))
        logger.info('-------------epoch:%d--------------' % (epoch))
        model.train()
        loss_tr = 0
        local_steps_cnt = 0
        #########   train ###########
        print('start training!')
        for batch_idx, batch in tqdm(
                enumerate(train_loader),
                total=int(len(train_loader.dataset) / batch_size) + 1):
            src_batch,tgt_batch,src_pad_batch,tgt_pad_batch,tgt_mask_batch=batch['src_ids'], \
                                    batch['tgt_ids'],batch['src_pad_mask'],batch['tgt_pad_mask'],batch['tgt_mask']

            src_batch = src_batch.to(device)
            tgt_batch = tgt_batch.to(device)
            src_pad_batch = src_pad_batch.to(device)
            tgt_pad_batch = tgt_pad_batch.to(device)
            tgt_mask_batch = tgt_mask_batch.to(device)

            model.zero_grad()
            out = model(src_batch, tgt_batch, src_pad_batch, tgt_pad_batch,
                        tgt_mask_batch)
            loss = criterion(out, tgt_batch)
            loss.backward()  # compute gradients
            optim.step()  # update parameters
            steps_cnt += 1
            local_steps_cnt += 1
            loss_tr += loss.item()

            # Periodic console progress.
            if batch_idx % loss_check_freq == 0:
                print('batch:%d' % (batch_idx))
                print('loss:%f' % (loss.item()))

            # Periodic validation + best-checkpoint saving.
            if steps_cnt % check_steps == 0:
                # Average training loss since the previous check.
                loss_tr /= local_steps_cnt
                print('trainset loss:%f' % (loss_tr))
                logger.info('trainset loss:%f' % (loss_tr))
                history['train_loss'].append(loss_tr)
                loss_tr = 0
                local_steps_cnt = 0
                #########  val ############
                loss_va = 0
                model.eval()
                with torch.no_grad():
                    print('start validating!')
                    # NOTE(review): this inner loop reuses the name
                    # `batch_idx`, shadowing the outer loop's index; it is
                    # reassigned on the next outer iteration, but renaming
                    # would be safer.
                    for batch_idx, batch in tqdm(
                            enumerate(dev_loader),
                            total=int(len(dev_loader.dataset) / batch_size) +
                            1):
                        src_batch,tgt_batch,src_pad_batch,tgt_pad_batch,tgt_mask_batch=batch['src_ids'], \
                                                batch['tgt_ids'],batch['src_pad_mask'],batch['tgt_pad_mask'],batch['tgt_mask']

                        src_batch = src_batch.to(device)
                        tgt_batch = tgt_batch.to(device)
                        src_pad_batch = src_pad_batch.to(device)
                        tgt_pad_batch = tgt_pad_batch.to(device)
                        tgt_mask_batch = tgt_mask_batch.to(device)

                        model.zero_grad()
                        out = model(src_batch, tgt_batch, src_pad_batch,
                                    tgt_pad_batch, tgt_mask_batch)
                        loss = criterion(out, tgt_batch)
                        loss_va += loss.item()

                    # Mean validation loss over all dev batches.
                    loss_va = loss_va / (batch_idx + 1)
                    print('valset loss:%f' % (loss_va))
                    logger.info('valset loss:%f' % (loss_va))
                    history['val_loss'].append(loss_va)

                    # save checkpoint and model
                    if loss_va < best_loss:
                        logger.info('Checkpoint Saving...')
                        print('best loss so far! Checkpoint Saving...')
                        state = {
                            'epoch': epoch,
                            'loss': loss_va,
                            'history': history
                        }
                        torch.save(state, checkpoint_name)
                        best_loss = loss_va
                        ## save model
                        torch.save(model.state_dict(), model_ckpt_name)
                scheduler.step()
                logger.info("current lr:%f" % (scheduler.get_last_lr()[0]))
                # Back to training mode after validation.
                model.train()
            # Periodic unconditional weight snapshot.
            if steps_cnt % save_steps == 0:
                logger.info('match save steps,Checkpoint Saving...')
                torch.save(
                    model.state_dict(),
                    os.path.join(
                        checkpoint_path,
                        model_name + '_steps_' + str(steps_cnt) + '.pkl'))
def train(hyp, args, device, train_loader, test_loader, tb_writer=None):
    """Train a small CNN on MNIST-style data with an exponentially decaying LR.

    Args:
        hyp: dict of hyperparameters ('lr', 'momentum', 'gamma').
        args: namespace with epochs, log_interval, dry_run, save_model, logdir.
        device: torch device to train on.
        train_loader/test_loader: data loaders.
        tb_writer: optional TensorBoard writer; also decides the log dir.

    Returns:
        The fitness (test accuracy) of the last epoch, as reported by `test`.
    """
    init_seeds()

    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(),
                           lr=hyp['lr'],
                           betas=(hyp['momentum'], 0.999))
    scheduler = StepLR(optimizer, step_size=1, gamma=hyp['gamma'])

    log_dir = Path(tb_writer.log_dir) if tb_writer else Path(
        args.logdir) / 'evolve'  # logging directory
    results_file = str(log_dir / 'results.txt')
    wdir = log_dir / 'weights'  # weights directory
    os.makedirs(wdir, exist_ok=True)
    last = wdir / 'last.pt'
    best = wdir / 'best.pt'

    # Save run settings
    with open(log_dir / 'hyp.yaml', 'w+') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(log_dir / 'opt.yaml', 'w+') as f:
        yaml.dump(vars(args), f, sort_keys=False)

    model.train()

    best_fitness = 0.0
    fitness = 0.0  # robustness: defined even when args.epochs < 1
    for epoch in range(1, args.epochs + 1):

        # BUG FIX: the loop runs 1..args.epochs inclusive, so the final epoch
        # is `epoch == args.epochs`; the original `epoch + 1 == args.epochs`
        # flagged the second-to-last epoch instead.
        final_epoch = epoch == args.epochs

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % args.log_interval == 0:
                print(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR: {}'.
                    format(epoch, batch_idx * len(data),
                           len(train_loader.dataset),
                           100. * batch_idx / len(train_loader), loss.item(),
                           scheduler.get_last_lr()))
                if args.dry_run:
                    break
        fitness = test(model, device, test_loader)
        if fitness > best_fitness:
            best_fitness = fitness
        scheduler.step()

        if tb_writer:
            tags = ['train/loss', 'test/accuracy(%)']  # params
            for x, tag in zip([loss.item(), fitness], tags):
                tb_writer.add_scalar(tag, x, epoch)

        if args.save_model:
            # BUG FIX: results.txt is never written by this function, so an
            # unconditional open(..., 'r') crashed with FileNotFoundError.
            # Read it only if something else produced it.
            training_results = ''
            if os.path.isfile(results_file):
                with open(results_file, 'r') as f:
                    training_results = f.read()
            ckpt = {
                'epoch': epoch,
                'best_fitness': best_fitness,
                'training_results': training_results,
                'model': model.state_dict(),
                'optimizer': None if final_epoch else optimizer.state_dict()
            }
            torch.save(model.state_dict(), "mnist_cnn.pt")
            # Save last, best and delete
            torch.save(ckpt, last)
            if best_fitness == fitness:
                torch.save(ckpt, best)

    return fitness
class Runner(object):
    """Few-shot relation-network trainer with an auxiliary instance
    classification (IC) branch on mini-ImageNet.

    Trains a feature encoder + relation network on episodic tasks (MSE loss
    against relation targets) jointly with an IC model whose pseudo-labels
    are produced/updated by `ProduceClass`. Periodically validates both
    branches and saves the best networks.
    """

    def __init__(self):
        # Best few-shot validation accuracy seen so far.
        self.best_accuracy = 0.0

        # all data
        self.data_train = MiniImageNetDataset.get_data_all(Config.data_root)
        self.task_train = MiniImageNetDataset(self.data_train, Config.num_way,
                                              Config.num_shot)
        # Positional `True` is DataLoader's shuffle flag.
        self.task_train_loader = DataLoader(self.task_train,
                                            Config.batch_size,
                                            True,
                                            num_workers=Config.num_workers)

        # IC
        # Pseudo-label producer for the instance-classification branch.
        self.produce_class = ProduceClass(len(self.data_train),
                                          Config.ic_out_dim, Config.ic_ratio)
        self.produce_class.init()
        self.task_train.set_samples_class(self.produce_class.classes)

        # model
        self.feature_encoder = RunnerTool.to_cuda(Config.feature_encoder)
        self.relation_network = RunnerTool.to_cuda(Config.relation_network)
        self.ic_model = RunnerTool.to_cuda(ICResNet(low_dim=Config.ic_out_dim))

        RunnerTool.to_cuda(self.feature_encoder.apply(RunnerTool.weights_init))
        RunnerTool.to_cuda(self.relation_network.apply(
            RunnerTool.weights_init))
        RunnerTool.to_cuda(self.ic_model.apply(RunnerTool.weights_init))

        # optim
        self.feature_encoder_optim = torch.optim.Adam(
            self.feature_encoder.parameters(), lr=Config.learning_rate)
        self.relation_network_optim = torch.optim.Adam(
            self.relation_network.parameters(), lr=Config.learning_rate)
        self.ic_model_optim = torch.optim.Adam(self.ic_model.parameters(),
                                               lr=Config.learning_rate)

        # All three LRs are halved every third of the training schedule.
        self.feature_encoder_scheduler = StepLR(self.feature_encoder_optim,
                                                Config.train_epoch // 3,
                                                gamma=0.5)
        self.relation_network_scheduler = StepLR(self.relation_network_optim,
                                                 Config.train_epoch // 3,
                                                 gamma=0.5)
        self.ic_model_scheduler = StepLR(self.ic_model_optim,
                                         Config.train_epoch // 3,
                                         gamma=0.5)

        # loss
        # MSE for relation scores (FSL branch), CE for IC pseudo-labels.
        self.ic_loss = RunnerTool.to_cuda(nn.CrossEntropyLoss())
        self.fsl_loss = RunnerTool.to_cuda(nn.MSELoss())

        # Eval
        self.test_tool_fsl = TestTool(self.compare_fsl_test,
                                      data_root=Config.data_root,
                                      num_way=Config.num_way,
                                      num_shot=Config.num_shot,
                                      episode_size=Config.episode_size,
                                      test_episode=Config.test_episode,
                                      transform=self.task_train.transform_test)
        self.test_tool_ic = ICTestTool(feature_encoder=None,
                                       ic_model=self.ic_model,
                                       data_root=Config.data_root,
                                       batch_size=Config.batch_size,
                                       num_workers=Config.num_workers,
                                       ic_out_dim=Config.ic_out_dim)
        pass

    def load_model(self):
        """Restore encoder / relation network / IC model weights from disk,
        each only if its checkpoint file exists."""
        if os.path.exists(Config.fe_dir):
            self.feature_encoder.load_state_dict(torch.load(Config.fe_dir))
            Tools.print("load feature encoder success from {}".format(
                Config.fe_dir))

        if os.path.exists(Config.rn_dir):
            self.relation_network.load_state_dict(torch.load(Config.rn_dir))
            Tools.print("load relation network success from {}".format(
                Config.rn_dir))

        if os.path.exists(Config.ic_dir):
            self.ic_model.load_state_dict(torch.load(Config.ic_dir))
            Tools.print("load ic model success from {}".format(Config.ic_dir))
        pass

    def compare_fsl(self, task_data):
        """Compute relation scores for a batch of episodic tasks.

        task_data: (batch, image_num, C, H, W); images are flattened through
        the feature encoder, then support/query features are paired and
        scored by the relation network.
        Returns relations of shape (-1, num_way * num_shot).
        """
        data_batch_size, data_image_num, data_num_channel, data_width, data_weight = task_data.shape
        fe_inputs = task_data.view(
            [-1, data_num_channel, data_width, data_weight])  # 90, 3, 84, 84

        # feature encoder
        data_features = self.feature_encoder(fe_inputs)  # 90x64*19*19
        _, feature_dim, feature_width, feature_height = data_features.shape

        # calculate
        data_features = data_features.view(
            [-1, data_image_num, feature_dim, feature_width, feature_height])
        # First num_shot*num_way images per task are support, rest query.
        data_features_support, data_features_query = data_features.split(
            Config.num_shot * Config.num_way, dim=1)
        data_features_query_repeat = data_features_query.repeat(
            1, Config.num_shot * Config.num_way, 1, 1, 1)

        # calculate relations
        # Support/query features are concatenated channel-wise per pair.
        relation_pairs = torch.cat(
            (data_features_support, data_features_query_repeat), 2)
        relation_pairs = relation_pairs.view(-1, feature_dim * 2,
                                             feature_width, feature_height)

        relations = self.relation_network(relation_pairs)
        relations = relations.view(-1, Config.num_way * Config.num_shot)

        return relations

    def compare_fsl_test(self, samples, batches):
        """Relation scores at test time: every query in `batches` is compared
        against every support image in `samples`."""
        # calculate features
        sample_features = self.feature_encoder(samples)  # 5x64*19*19
        batch_features = self.feature_encoder(batches)  # 75x64*19*19
        batch_size, feature_dim, feature_width, feature_height = batch_features.shape

        # calculate
        sample_features_ext = sample_features.unsqueeze(0).repeat(
            batch_size, 1, 1, 1, 1)
        batch_features_ext = batch_features.unsqueeze(1).repeat(
            1, Config.num_shot * Config.num_way, 1, 1, 1)

        # calculate relations
        relation_pairs = torch.cat((sample_features_ext, batch_features_ext),
                                   2)
        relation_pairs = relation_pairs.view(-1, feature_dim * 2,
                                             feature_width, feature_height)

        relations = self.relation_network(relation_pairs)
        relations = relations.view(-1, Config.num_way * Config.num_shot)
        return relations

    def train(self):
        """Main training loop: initialize IC pseudo-labels with one full pass,
        then jointly train FSL + IC branches for Config.train_epoch epochs,
        validating every Config.val_freq epochs."""
        Tools.print()
        Tools.print("Training...")

        # Init Update
        # One evaluation-mode pass over the data to seed the IC pseudo-labels
        # before training starts.
        # NOTE(review): the try/finally only wraps `pass` — it has no cleanup
        # effect; presumably a leftover from an earlier error handler.
        try:
            self.feature_encoder.eval()
            self.relation_network.eval()
            self.ic_model.eval()
            # NOTE(review): the "{}" placeholder below is never filled in.
            Tools.print("Init label {} .......")
            self.produce_class.reset()
            for task_data, task_labels, task_index in tqdm(
                    self.task_train_loader):
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)
                ic_out_logits, ic_out_l2norm = self.ic_model(task_data[:, -1])
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)
                pass
            # NOTE(review): despite the "Epoch:" label, this prints the
            # ProduceClass change counters, not an epoch number.
            Tools.print("Epoch: {}/{}".format(self.produce_class.count,
                                              self.produce_class.count_2))
        finally:
            pass

        for epoch in range(Config.train_epoch):
            self.feature_encoder.train()
            self.relation_network.train()
            self.ic_model.train()

            Tools.print()
            self.produce_class.reset()
            Tools.print(self.task_train.classes)
            all_loss, all_loss_fsl, all_loss_ic = 0.0, 0.0, 0.0
            for task_data, task_labels, task_index in tqdm(
                    self.task_train_loader):
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)

                ###########################################################################
                # 1 calculate features
                relations = self.compare_fsl(task_data)
                ic_out_logits, ic_out_l2norm = self.ic_model(task_data[:, -1])

                # 2
                # Read current pseudo-labels, then update them with this
                # batch's embeddings.
                ic_targets = self.produce_class.get_label(ic_labels)
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)

                # 3 loss
                # Weighted sum of the FSL (MSE) and IC (CE) losses.
                loss_fsl = self.fsl_loss(relations,
                                         task_labels) * Config.loss_fsl_ratio
                loss_ic = self.ic_loss(ic_out_logits,
                                       ic_targets) * Config.loss_ic_ratio
                loss = loss_fsl + loss_ic
                all_loss += loss.item()
                all_loss_fsl += loss_fsl.item()
                all_loss_ic += loss_ic.item()

                # 4 backward
                self.feature_encoder.zero_grad()
                self.relation_network.zero_grad()
                self.ic_model.zero_grad()
                loss.backward()
                # Gradient clipping at norm 0.5 for all three networks.
                torch.nn.utils.clip_grad_norm_(
                    self.feature_encoder.parameters(), 0.5)
                torch.nn.utils.clip_grad_norm_(
                    self.relation_network.parameters(), 0.5)
                torch.nn.utils.clip_grad_norm_(self.ic_model.parameters(), 0.5)
                self.feature_encoder_optim.step()
                self.relation_network_optim.step()
                self.ic_model_optim.step()
                ###########################################################################
                pass

            ###########################################################################
            # print
            Tools.print("{:6} loss:{:.3f} fsl:{:.3f} ic:{:.3f} lr:{}".format(
                epoch + 1, all_loss / len(self.task_train_loader),
                all_loss_fsl / len(self.task_train_loader),
                all_loss_ic / len(self.task_train_loader),
                self.feature_encoder_scheduler.get_last_lr()))
            Tools.print("Train: [{}] {}/{}".format(epoch,
                                                   self.produce_class.count,
                                                   self.produce_class.count_2))
            self.feature_encoder_scheduler.step()
            self.relation_network_scheduler.step()
            self.ic_model_scheduler.step()
            ###########################################################################

            ###########################################################################
            # Val
            if epoch % Config.val_freq == 0:
                self.feature_encoder.eval()
                self.relation_network.eval()
                self.ic_model.eval()

                self.test_tool_ic.val(epoch=epoch)
                val_accuracy = self.test_tool_fsl.val(episode=epoch,
                                                      is_print=True)

                # Persist all three networks whenever validation improves.
                if val_accuracy > self.best_accuracy:
                    self.best_accuracy = val_accuracy
                    torch.save(self.feature_encoder.state_dict(),
                               Config.fe_dir)
                    torch.save(self.relation_network.state_dict(),
                               Config.rn_dir)
                    torch.save(self.ic_model.state_dict(), Config.ic_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            ###########################################################################
            pass

        pass

    pass
class Runner(object):
    """Plain few-shot relation-network trainer on mini-ImageNet (no auxiliary
    IC branch, unlike the variant above).

    Trains a feature encoder + relation network on episodic tasks with an
    MSE loss, validates every Config.val_freq epochs, and saves the best
    networks to Config.fe_dir / Config.rn_dir.
    """

    def __init__(self):
        # Best few-shot validation accuracy seen so far.
        self.best_accuracy = 0.0

        # all data
        self.data_train = MiniImageNetDataset.get_data_all(Config.data_root)
        self.task_train = MiniImageNetDataset(self.data_train, Config.num_way,
                                              Config.num_shot)
        self.task_train_loader = DataLoader(self.task_train,
                                            Config.batch_size,
                                            shuffle=True,
                                            num_workers=Config.num_workers)

        # model
        self.feature_encoder = RunnerTool.to_cuda(Config.feature_encoder)
        self.relation_network = RunnerTool.to_cuda(Config.relation_network)
        RunnerTool.to_cuda(self.feature_encoder.apply(RunnerTool.weights_init))
        RunnerTool.to_cuda(self.relation_network.apply(
            RunnerTool.weights_init))

        # optim
        # Both LRs are halved every third of the training schedule.
        self.feature_encoder_optim = torch.optim.Adam(
            self.feature_encoder.parameters(), lr=Config.learning_rate)
        self.feature_encoder_scheduler = StepLR(self.feature_encoder_optim,
                                                Config.train_epoch // 3,
                                                gamma=0.5)
        self.relation_network_optim = torch.optim.Adam(
            self.relation_network.parameters(), lr=Config.learning_rate)
        self.relation_network_scheduler = StepLR(self.relation_network_optim,
                                                 Config.train_epoch // 3,
                                                 gamma=0.5)

        # loss
        self.loss = RunnerTool.to_cuda(nn.MSELoss())

        self.test_tool = TestTool(self.compare_fsl_test,
                                  data_root=Config.data_root,
                                  num_way=Config.num_way,
                                  num_shot=Config.num_shot,
                                  episode_size=Config.episode_size,
                                  test_episode=Config.test_episode,
                                  transform=self.task_train.transform_test)
        pass

    def load_model(self):
        """Restore encoder / relation network weights from disk, each only
        if its checkpoint file exists."""
        if os.path.exists(Config.fe_dir):
            self.feature_encoder.load_state_dict(torch.load(Config.fe_dir))
            Tools.print("load feature encoder success from {}".format(
                Config.fe_dir))

        if os.path.exists(Config.rn_dir):
            self.relation_network.load_state_dict(torch.load(Config.rn_dir))
            Tools.print("load relation network success from {}".format(
                Config.rn_dir))
        pass

    def compare_fsl(self, task_data):
        """Compute relation scores for a batch of episodic tasks.

        task_data: (batch, image_num, C, H, W); images are flattened through
        the feature encoder, then support/query features are paired and
        scored by the relation network.
        Returns relations of shape (-1, num_way * num_shot).
        """
        data_batch_size, data_image_num, data_num_channel, data_width, data_weight = task_data.shape
        fe_inputs = task_data.view(
            [-1, data_num_channel, data_width, data_weight])  # 90, 3, 84, 84

        # feature encoder
        data_features = self.feature_encoder(fe_inputs)  # 90x64*19*19
        _, feature_dim, feature_width, feature_height = data_features.shape

        # calculate
        data_features = data_features.view(
            [-1, data_image_num, feature_dim, feature_width, feature_height])
        # First num_shot*num_way images per task are support, rest query.
        data_features_support, data_features_query = data_features.split(
            Config.num_shot * Config.num_way, dim=1)
        data_features_query_repeat = data_features_query.repeat(
            1, Config.num_shot * Config.num_way, 1, 1, 1)

        # calculate relations
        # Support/query features are concatenated channel-wise per pair.
        relation_pairs = torch.cat(
            (data_features_support, data_features_query_repeat), 2)
        relation_pairs = relation_pairs.view(-1, feature_dim * 2,
                                             feature_width, feature_height)

        relations = self.relation_network(relation_pairs)
        relations = relations.view(-1, Config.num_way * Config.num_shot)
        return relations

    def compare_fsl_test(self, samples, batches):
        """Relation scores at test time: every query in `batches` is compared
        against every support image in `samples`."""
        # calculate features
        sample_features = self.feature_encoder(samples)  # 5x64*19*19
        batch_features = self.feature_encoder(batches)  # 75x64*19*19
        batch_size, feature_dim, feature_width, feature_height = batch_features.shape

        # calculate
        sample_features_ext = sample_features.unsqueeze(0).repeat(
            batch_size, 1, 1, 1, 1)
        batch_features_ext = batch_features.unsqueeze(1).repeat(
            1, Config.num_shot * Config.num_way, 1, 1, 1)

        # calculate relations
        relation_pairs = torch.cat((sample_features_ext, batch_features_ext),
                                   2)
        relation_pairs = relation_pairs.view(-1, feature_dim * 2,
                                             feature_width, feature_height)

        relations = self.relation_network(relation_pairs)
        relations = relations.view(-1, Config.num_way * Config.num_shot)
        return relations

    def train(self):
        """Main training loop over Config.train_epoch epochs with periodic
        validation and best-checkpoint saving."""
        Tools.print()
        Tools.print("Training...")

        for epoch in range(Config.train_epoch):
            self.feature_encoder.train()
            self.relation_network.train()

            Tools.print()
            all_loss = 0.0
            for task_data, task_labels, task_index in tqdm(
                    self.task_train_loader):
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)

                # 1 calculate features
                relations = self.compare_fsl(task_data)

                # 2 loss
                loss = self.loss(relations, task_labels)
                all_loss += loss.item()

                # 3 backward
                self.feature_encoder.zero_grad()
                self.relation_network.zero_grad()
                loss.backward()
                # Gradient clipping at norm 0.5 for both networks.
                torch.nn.utils.clip_grad_norm_(
                    self.feature_encoder.parameters(), 0.5)
                torch.nn.utils.clip_grad_norm_(
                    self.relation_network.parameters(), 0.5)
                self.feature_encoder_optim.step()
                self.relation_network_optim.step()
                ###########################################################################
                pass

            ###########################################################################
            # print
            Tools.print("{:6} loss:{:.3f} lr:{}".format(
                epoch + 1, all_loss / len(self.task_train_loader),
                self.feature_encoder_scheduler.get_last_lr()))

            self.feature_encoder_scheduler.step()
            self.relation_network_scheduler.step()
            ###########################################################################

            ###########################################################################
            # Val
            if epoch % Config.val_freq == 0:
                Tools.print()
                Tools.print("Test {} {} .......".format(
                    epoch, Config.model_name))
                self.feature_encoder.eval()
                self.relation_network.eval()

                val_accuracy = self.test_tool.val(episode=epoch, is_print=True)
                # Persist both networks whenever validation improves.
                if val_accuracy > self.best_accuracy:
                    self.best_accuracy = val_accuracy
                    torch.save(self.feature_encoder.state_dict(),
                               Config.fe_dir)
                    torch.save(self.relation_network.state_dict(),
                               Config.rn_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            ###########################################################################
            pass

        pass

    pass
class Runner(object):
    """Train and evaluate a Prototypical Network on MiniImageNet few-shot tasks.

    Wires together the episodic dataset/loader, the embedding network
    (``Config.proto_net``), an SGD optimizer with a step LR schedule, and a
    ``TestTool`` that evaluates ``proto_test`` on held-out episodes.
    """

    def __init__(self):
        # Best validation accuracy seen so far; drives checkpoint saving in train().
        self.best_accuracy = 0.0

        # all data
        self.data_train = MiniImageNetDataset.get_data_all(Config.data_root)
        self.task_train = MiniImageNetDataset(self.data_train, Config.num_way, Config.num_shot)
        self.task_train_loader = DataLoader(self.task_train, Config.batch_size,
                                            shuffle=True, num_workers=Config.num_workers)

        # model (moved to GPU and re-initialized with the project's weight init)
        self.proto_net = RunnerTool.to_cuda(Config.proto_net)
        RunnerTool.to_cuda(self.proto_net.apply(RunnerTool.weights_init))

        # optim: SGD with momentum; LR decays 10x at the halfway epoch.
        self.proto_net_optim = torch.optim.SGD(self.proto_net.parameters(),
                                               lr=Config.learning_rate, momentum=0.9, weight_decay=5e-4)
        self.proto_net_scheduler = StepLR(self.proto_net_optim, Config.train_epoch // 2, gamma=0.1)

        # Episode-based evaluator; scores episodes via self.proto_test.
        self.test_tool = TestTool(self.proto_test, data_root=Config.data_root,
                                  num_way=Config.num_way,  num_shot=Config.num_shot,
                                  episode_size=Config.episode_size, test_episode=Config.test_episode,
                                  transform=self.task_train.transform_test)
        pass

    def load_model(self):
        """Load pretrained and/or previously saved weights when the files exist.

        A checkpoint at ``Config.pn_dir`` is loaded second, so it overrides
        the pretrain weights from ``Config.pn_pretrain``.
        """
        if os.path.exists(Config.pn_pretrain):
            self.proto_net.load_state_dict(torch.load(Config.pn_pretrain))
            Tools.print("load proto net success from {}".format(Config.pn_pretrain))
            pass

        if os.path.exists(Config.pn_dir):
            self.proto_net.load_state_dict(torch.load(Config.pn_dir))
            Tools.print("load proto net success from {}".format(Config.pn_dir))
            pass
        pass

    def proto(self, task_data):
        """Compute per-class log-probabilities for a batch of training tasks.

        ``task_data`` is (batch, num_image, channel, width, height); the first
        ``num_shot * num_way`` images of each task form the support set and
        the remainder the query set. Returns the log-softmax over negative
        squared Euclidean distances from each query embedding to the class
        prototypes (mean support embedding per class).
        """
        data_batch_size, data_image_num, data_num_channel, data_width, data_weight = task_data.shape
        # Flatten all task images into one batch for the encoder.
        data_x = task_data.view(-1, data_num_channel, data_width, data_weight)
        net_out = self.proto_net(data_x)
        z = net_out.view(data_batch_size, data_image_num, -1)

        # Split the embeddings into support and query parts.
        z_support, z_query = z.split(Config.num_shot * Config.num_way, dim=1)
        z_batch_size, z_num, z_dim = z_support.shape
        z_support = z_support.view(z_batch_size, Config.num_way, Config.num_shot, z_dim)

        # Prototype = mean of each class's support embeddings.
        z_support_proto = z_support.mean(2)
        # NOTE(review): expand() only broadcasts size-1 dims, so this assumes
        # a single query image per task — TODO confirm against the dataset.
        z_query_expand = z_query.expand(z_batch_size, Config.num_way, z_dim)

        dists = torch.pow(z_query_expand - z_support_proto, 2).sum(2)
        log_p_y = F.log_softmax(-dists, dim=1)
        return log_p_y

    def proto_test(self, samples, batches):
        """Score query images against support samples at evaluation time.

        ``samples`` holds the support images (num_way * num_shot of them),
        ``batches`` the query images. Returns the log-softmax over negative
        squared distances from each query to each class prototype.
        """
        batch_num, _, _, _ = batches.shape

        sample_z = self.proto_net(samples)  # 5x64*5*5
        batch_z = self.proto_net(batches)  # 75x64*5*5
        sample_z = sample_z.view(Config.num_way, Config.num_shot, -1)
        batch_z = batch_z.view(batch_num, -1)
        _, z_dim = batch_z.shape

        # Class prototypes: mean over the shot dimension.
        z_proto = sample_z.mean(1)
        z_proto_expand = z_proto.unsqueeze(0).expand(batch_num, Config.num_way, z_dim)
        z_query_expand = batch_z.unsqueeze(1).expand(batch_num, Config.num_way, z_dim)

        dists = torch.pow(z_query_expand - z_proto_expand, 2).sum(2)
        log_p_y = F.log_softmax(-dists, dim=1)
        return log_p_y

    def train(self):
        """Main training loop: episodic training with periodic validation
        and checkpointing of the best model (by validation accuracy)."""
        Tools.print()
        Tools.print("Training...")

        for epoch in range(Config.train_epoch):
            self.proto_net.train()

            Tools.print()
            all_loss = 0.0
            for task_data, task_labels, task_index in tqdm(self.task_train_loader):
                task_data, task_labels = RunnerTool.to_cuda(task_data), RunnerTool.to_cuda(task_labels)

                # 1 calculate features
                log_p_y = self.proto(task_data)

                # 2 loss: negative log-likelihood of the labelled classes,
                # normalized by the number of positive labels.
                loss = -(log_p_y * task_labels).sum() / task_labels.sum()
                all_loss += loss.item()

                # 3 backward (with gradient clipping at norm 0.5)
                self.proto_net.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.proto_net.parameters(), 0.5)
                self.proto_net_optim.step()
                ###########################################################################
                pass

            ###########################################################################
            # print
            Tools.print("{:6} loss:{:.3f} lr:{}".format(
                epoch + 1, all_loss / len(self.task_train_loader), self.proto_net_scheduler.get_last_lr()))

            self.proto_net_scheduler.step()
            ###########################################################################

            ###########################################################################
            # Val: evaluate every Config.val_freq epochs; save on improvement.
            if epoch % Config.val_freq == 0:
                Tools.print()
                Tools.print("Test {} {} .......".format(epoch, Config.model_name))

                self.proto_net.eval()
                val_accuracy = self.test_tool.val(episode=epoch, is_print=True)
                if val_accuracy > self.best_accuracy:
                    self.best_accuracy = val_accuracy
                    torch.save(self.proto_net.state_dict(), Config.pn_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            ###########################################################################
            pass

        pass

    pass
示例#10
0
    def fit(self, x, y, n_epoch=30, valid_x=None, valid_y=None):
        '''
        Train the VAE on a labelled time-series dataset.

        If no labels are available in practice, most samples can be treated
        as normal, i.e. set all of y to 0.

        :param x: training windows
        :param y: anomaly labels for x (0 = normal)
        :param n_epoch: number of training epochs
        :param valid_x: optional validation windows
        :param valid_y: optional validation labels
        :return: None
        '''
        # todo missing injection
        self._vae.train()
        train_dataset = TsDataset(x, y)
        train_iter = torch.utils.data.DataLoader(train_dataset,
                                                 batch_size=256,
                                                 shuffle=True,
                                                 num_workers=0)

        if valid_x is not None:
            valid_dataset = TsDataset(valid_x, valid_y)
            valid_iter = torch.utils.data.DataLoader(valid_dataset,
                                                     batch_size=128,
                                                     shuffle=False,
                                                     num_workers=0)

        optimizer = self.optimizer  # TODO: dynamic learning rate
        lr_scheduler = StepLR(optimizer, step_size=100, gamma=0.75)
        for epoch in range(n_epoch):
            for train_x, train_y in train_iter:
                optimizer.zero_grad()
                z, x_miu, x_std, z_miu, z_std = self._vae(train_x)  # forward pass
                l = self.m_elbo_loss(train_x, train_y, z, x_miu, x_std, z_miu,
                                     z_std)
                l.backward()
                optimizer.step()

            # BUGFIX: step the LR scheduler AFTER the epoch's optimizer
            # updates. The previous code stepped it at the top of each epoch,
            # before any optimizer.step(), which skips the first LR value of
            # the schedule and triggers a PyTorch (>=1.1) ordering warning.
            lr_scheduler.step()

            # Checkpoint the model every 100 epochs.
            if epoch % 100 == 0:
                print("保存模型")
                torch.save(
                    {
                        "state_dict": self._vae.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "loss": l.item()
                    }, "./model_parameters/epoch{}-loss{:.2f}.tar".format(
                        epoch, l.item()))

            # Validation pass every 50 epochs (when validation data is given).
            if epoch % 50 == 0 and valid_x is not None:
                with torch.no_grad():
                    flag = 1
                    for v_x, v_y in valid_iter:
                        z, x_miu, x_std, z_miu, z_std = self._vae(v_x)
                        v_l = self.m_elbo_loss(v_x, v_y, z, x_miu, x_std,
                                               z_miu, z_std)
                        if flag == 1:
                            # Plot the reconstruction of the first validation
                            # window only once per validation round.
                            flag = 0
                            v_x_ = v_x[0].view(1, 120)
                            z, x_miu, x_std, z_miu, z_std = self._vae(v_x_)
                            restruct_compare_plot(v_x_.view(120),
                                                  x_miu.view(120))
                    print("train loss %.4f,  valid loss %.4f" %
                          (l.item(), v_l.item()))
                    with open("log.txt", "a") as f:
                        f.writelines("%d %.4f %.4f\n" %
                                     (epoch, l.item(), v_l.item()))
            else:
                print("loss", l.item(), "  lr,", lr_scheduler.get_last_lr())
class Runner(object):
    """Self-supervised ProtoNet training on MiniImageNet with mixup losses.

    Each batch supplies triples of images (anchor, two others); the runner
    synthesizes mixed images, embeds everything with ``Config.proto_net``,
    and optimizes a triplet loss plus a mixup-consistency loss in embedding
    space. Few-shot evaluation reuses ``proto_test`` via ``TestTool``.
    """
    def __init__(self):
        # Best validation accuracy seen so far; drives checkpoint saving.
        self.best_accuracy = 0.0

        # all data
        self.data_train = MiniImageNetDataset.get_data_all(Config.data_root)
        self.task_train = MiniImageNetDataset(self.data_train)
        self.task_train_loader = DataLoader(self.task_train,
                                            Config.batch_size,
                                            shuffle=True,
                                            num_workers=Config.num_workers)

        # model: embedding network plus an L2-normalization layer, on GPU.
        self.proto_net = RunnerTool.to_cuda(Config.proto_net)
        self.norm = RunnerTool.to_cuda(Normalize())
        RunnerTool.to_cuda(self.proto_net.apply(RunnerTool.weights_init))

        # optim: Adam; LR halves every third of the training schedule.
        self.proto_net_optim = torch.optim.Adam(self.proto_net.parameters(),
                                                lr=Config.learning_rate)
        self.proto_net_scheduler = StepLR(self.proto_net_optim,
                                          Config.train_epoch // 3,
                                          gamma=0.5)

        # Episode-based evaluator; scores episodes via self.proto_test.
        self.test_tool = TestTool(self.proto_test,
                                  data_root=Config.data_root,
                                  num_way=Config.num_way,
                                  num_shot=Config.num_shot,
                                  episode_size=Config.episode_size,
                                  test_episode=Config.test_episode,
                                  transform=self.task_train.transform_test)

        # loss (cosine variant used by mixup_loss2, triplet by mixup_loss)
        self.cosine_loss = torch.nn.CosineEmbeddingLoss()
        self.triple_margin_loss = torch.nn.TripletMarginLoss()
        pass

    def load_model(self):
        """Restore saved proto-net weights from ``Config.pn_dir`` if present."""
        if os.path.exists(Config.pn_dir):
            self.proto_net.load_state_dict(torch.load(Config.pn_dir))
            Tools.print("load proto net success from {}".format(Config.pn_dir))
        pass

    def mixup_loss2(self, z, beta_lambda):
        """Cosine-embedding variant of the mixup loss (alternative to mixup_loss).

        ``z`` is (batch, 6, c, w, h): embeddings of anchor, two other images,
        and the three pairwise mixes. ``beta_lambda`` is (batch, 6) holding
        the mix coefficients [beta, 1 - beta]. Returns (total, cosine part,
        mixup part).
        """
        batch_size, num, c, w, h = z.shape
        z = z.view(batch_size, num, -1)
        # Unpack: anchor, image 1, image 2, and the mixes a+1, 1+2, a+2.
        x_a, x_1, x_2, x_a1, x_12, x_a2 = [
            torch.squeeze(one) for one in z.split(split_size=1, dim=1)
        ]

        # Broadcast each mix coefficient over the flattened embedding dim.
        bl_tile = RunnerTool.to_cuda(
            torch.tensor(np.tile(beta_lambda[..., None], [c * w * h])))
        # Columns follow the hstack([beta, 1-beta]) layout of beta_lambda.
        p_1_1, p_2_1, p_3_1, p_1_2, p_2_2, p_3_2 = [
            torch.squeeze(one) for one in bl_tile.split(split_size=1, dim=1)
        ]

        # Targets for CosineEmbeddingLoss: +1 pull together, -1 push apart.
        targets_0 = RunnerTool.to_cuda(torch.tensor(-np.ones(batch_size)))
        targets_1 = RunnerTool.to_cuda(torch.tensor(np.ones(batch_size)))

        # Similarity structure: x_1/x_a1 are positives, x_2 is the negative.
        cosine_1 = self.cosine_loss(x_1, x_a, targets_1)
        cosine_2 = self.cosine_loss(x_1, x_2, targets_0)
        cosine_3 = self.cosine_loss(x_a, x_2, targets_0)
        cosine_4 = self.cosine_loss(x_a1, x_a, targets_1)
        cosine_5 = self.cosine_loss(x_a1, x_1, targets_1)
        cosine_6 = self.cosine_loss(x_a1, x_2, targets_0)
        loss_cosine = cosine_1 + cosine_2 + cosine_3 + cosine_4 + cosine_5 + cosine_6

        # Mixup consistency: mixed embedding should match embedding of mix.
        cosine_7 = self.cosine_loss(p_1_1 * x_a + p_1_2 * x_1, x_a1, targets_1)
        cosine_8 = self.cosine_loss(p_2_1 * x_1 + p_2_2 * x_2, x_12, targets_1)
        cosine_9 = self.cosine_loss(p_3_1 * x_2 + p_3_2 * x_a, x_a2, targets_1)

        loss_mixup = cosine_7 + cosine_8 + cosine_9
        loss_mixup = loss_mixup * Config.mix_ratio

        # 2 loss
        loss = loss_cosine + loss_mixup

        return loss, loss_cosine, loss_mixup

    def mixup_loss(self, z, beta_lambda):
        """Triplet + mixup-consistency loss used by train().

        Same inputs as mixup_loss2; returns (total, triplet part, mixup part).
        """
        batch_size, num, c, w, h = z.shape
        z = z.view(batch_size, num, -1)
        # Unpack: anchor, image 1, image 2, and the mixes a+1, 1+2, a+2.
        x_a, x_1, x_2, x_a1, x_12, x_a2 = [
            torch.squeeze(one) for one in z.split(split_size=1, dim=1)
        ]

        # Broadcast each mix coefficient over the flattened embedding dim.
        bl_tile = RunnerTool.to_cuda(
            torch.tensor(np.tile(beta_lambda[..., None], [c * w * h])))
        # Columns follow the hstack([beta, 1-beta]) layout of beta_lambda.
        p_1_1, p_2_1, p_3_1, p_1_2, p_2_2, p_3_2 = [
            torch.squeeze(one) for one in bl_tile.split(split_size=1, dim=1)
        ]

        # Triplets: (anchor, positive, negative) with x_2 as the negative.
        triple_1 = self.triple_margin_loss(
            x_1, x_a, x_2) + self.triple_margin_loss(x_a, x_1, x_2)
        triple_2 = self.triple_margin_loss(
            x_a1, x_a, x_2) + self.triple_margin_loss(x_a1, x_1, x_2)
        loss_triple = triple_1 + triple_2

        # Mixup consistency as squared L2 between mixed embeddings and the
        # embeddings of the mixed images.
        mixup_1 = torch.mean(
            torch.sum(torch.pow(p_1_1 * x_a + p_1_2 * x_1 - x_a1, 2), dim=1))
        mixup_2 = torch.mean(
            torch.sum(torch.pow(p_2_1 * x_1 + p_2_2 * x_2 - x_12, 2), dim=1))
        mixup_3 = torch.mean(
            torch.sum(torch.pow(p_3_1 * x_2 + p_3_2 * x_a - x_a2, 2), dim=1))
        loss_mixup = mixup_1 + mixup_2 + mixup_3
        loss_mixup = loss_mixup * Config.mix_ratio

        # 2 loss
        loss = loss_triple + loss_mixup

        return loss, loss_triple, loss_mixup

    def proto_test(self, samples, batches):
        """Score query images against support samples at evaluation time.

        Embeddings are L2-normalized before computing squared-distance
        logits against the class prototypes (mean support embeddings).
        """
        batch_num, _, _, _ = batches.shape

        sample_z = self.proto_net(samples)  # 5x64*5*5
        batch_z = self.proto_net(batches)  # 75x64*5*5

        sample_z = self.norm(sample_z)
        batch_z = self.norm(batch_z)

        sample_z = sample_z.view(Config.num_way, Config.num_shot, -1)
        batch_z = batch_z.view(batch_num, -1)
        _, z_dim = batch_z.shape

        # Class prototypes: mean over the shot dimension.
        z_proto = sample_z.mean(1)
        z_proto_expand = z_proto.unsqueeze(0).expand(batch_num, Config.num_way,
                                                     z_dim)
        z_query_expand = batch_z.unsqueeze(1).expand(batch_num, Config.num_way,
                                                     z_dim)

        dists = torch.pow(z_query_expand - z_proto_expand, 2).sum(2)
        log_p_y = F.log_softmax(-dists, dim=1)
        return log_p_y

    def train(self):
        """Main loop: build mixed images per batch, optimize mixup_loss,
        validate every Config.val_freq epochs, checkpoint the best model."""
        Tools.print()
        Tools.print("Training...")

        for epoch in range(Config.train_epoch):
            self.proto_net.train()

            Tools.print()
            all_loss, all_loss_triple, all_loss_mixup = 0.0, 0.0, 0.0
            for task_tuple, inputs, task_index in tqdm(self.task_train_loader):
                batch_size, num, c, w, h = inputs.shape

                # beta = np.random.beta(1, 1, [batch_size, num])  # 64, 3
                # Fixed 50/50 mixing coefficients (random beta disabled above).
                beta = np.zeros([batch_size, num]) + 0.5  # 64, 3

                beta_lambda = np.hstack([beta, 1 - beta])  # 64, 6
                beta_lambda_tile = np.tile(beta_lambda[..., None, None, None],
                                           [c, w, h])

                # Build the three pairwise mixes by pairing each image with
                # its cyclic successor, then weighting by beta / (1 - beta).
                inputs_1 = torch.cat(
                    [inputs, inputs[:, 1:, ...], inputs[:, 0:1, ...]],
                    dim=1) * beta_lambda_tile
                inputs_1 = (inputs_1[:, 0:num, ...] +
                            inputs_1[:, num:, ...]).float()
                now_inputs = torch.cat([inputs, inputs_1],
                                       dim=1).view(-1, c, w, h)
                now_inputs = RunnerTool.to_cuda(now_inputs)

                # 1 calculate features
                net_out = self.proto_net(now_inputs)
                net_out = self.norm(net_out)

                _, out_c, out_w, out_h = net_out.shape
                z = net_out.view(batch_size, -1, out_c, out_w, out_h)

                # 2 calculate loss
                loss, loss_triple, loss_mixup = self.mixup_loss(
                    z, beta_lambda=beta_lambda)
                all_loss += loss.item()
                all_loss_mixup += loss_mixup.item()
                all_loss_triple += loss_triple.item()

                # 3 backward (with gradient clipping at norm 0.5)
                self.proto_net.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.proto_net.parameters(),
                                               0.5)
                self.proto_net_optim.step()
                ###########################################################################
                pass

            ###########################################################################
            # print
            Tools.print(
                "{:6} loss:{:.3f} triple:{:.3f} mixup:{:.3f} lr:{}".format(
                    epoch + 1, all_loss / len(self.task_train_loader),
                    all_loss_triple / len(self.task_train_loader),
                    all_loss_mixup / len(self.task_train_loader),
                    self.proto_net_scheduler.get_last_lr()))

            self.proto_net_scheduler.step()
            ###########################################################################

            ###########################################################################
            # Val: evaluate every Config.val_freq epochs; save on improvement.
            if epoch % Config.val_freq == 0:
                Tools.print()
                Tools.print("Test {} {} .......".format(
                    epoch, Config.model_name))
                self.proto_net.eval()

                val_accuracy = self.test_tool.val(episode=epoch, is_print=True)
                if val_accuracy > self.best_accuracy:
                    self.best_accuracy = val_accuracy
                    torch.save(self.proto_net.state_dict(), Config.pn_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            ###########################################################################
            pass

        pass

    pass
示例#12
0
def main(args):
    """Train a DQN agent on Atari Pong (PongNoFrameskip-v4).

    Builds the environment, replay buffer and online/target Q-networks,
    then runs an epsilon-greedy, frame-based training loop that periodically
    syncs the target network, logs reward statistics, and checkpoints the
    model at reward milestones and fixed frame intervals.

    :param args: parsed command-line namespace of hyper-parameters and paths.
    """
    # CUDA
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    print("Using cuda: ", use_cuda)

    # Environment
    env_id = "PongNoFrameskip-v4"
    env = make_atari(env_id)
    env = wrap_deepmind(env, args.frame_stack)
    env = wrap_pytorch(env)

    # Random seed
    env.seed(args.seed)
    torch.manual_seed(args.seed)

    # Initializing
    replay_initial = 10000  #50000
    replay_buffer = ReplayBuffer(args.capacity)
    # model = QLearner(env, args, replay_buffer)
    # Initialize target q function and q function
    model_Q = QLearner(env, args, replay_buffer)
    model_target_Q = QLearner(env, args, replay_buffer)

    # BUGFIX: keep `scheduler` defined for every optimizer configuration so
    # the scheduler block in the training loop cannot raise NameError
    # (previously possible with e.g. RMSprop + use_optim_scheduler).
    scheduler = None
    if args.optimizer == 'Adam':
        if args.use_optim_scheduler:
            optimizer = optim.Adam(model_Q.parameters(), lr=args.initial_lr)
            scheduler = StepLR(optimizer,
                               step_size=args.step_size,
                               gamma=args.gamma)
            # scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=1000, verbose=True)
        else:
            optimizer = optim.Adam(model_Q.parameters(), args.lr)

    elif args.optimizer == 'RMSprop':
        optimizer = optim.RMSprop(model_Q.parameters(), args.lr)
    else:
        # Fail fast instead of hitting a NameError on first optimizer use.
        raise ValueError("Unsupported optimizer: {}".format(args.optimizer))

    # NOTE(review): USE_CUDA is a module-level flag; presumably consistent
    # with `use_cuda` computed above — confirm.
    if USE_CUDA:
        model_Q = model_Q.cuda()
        model_target_Q = model_target_Q.cuda()

    # Exponentially decaying epsilon for epsilon-greedy exploration.
    epsilon_by_frame = lambda frame_idx: args.epsilon_final + (
        args.epsilon_start - args.epsilon_final) * math.exp(-1. * frame_idx /
                                                            args.epsilon_decay)

    losses = []
    learning_rates = []
    all_rewards = []
    episode_reward = 0
    num_param_updates = 0
    mean_reward = -float('nan')
    mean_reward2 = -float('nan')
    best_mean_reward = -float('inf')
    best_mean_reward2 = -float('inf')

    # Best 10-episode mean reward seen at each milestone threshold.
    best_18_reward = -float('inf')
    best_19_reward = -float('inf')
    best_20_reward = -float('inf')
    best_21_reward = -float('inf')

    time_history = []  # records time (in sec) of each episode
    old_lr = args.initial_lr
    state = env.reset()
    start_time_frame = time.time()
    for frame_idx in range(1, args.num_frames + 1):
        start_time = time.time()

        # Act epsilon-greedily and store the transition.
        epsilon = epsilon_by_frame(frame_idx)
        action = model_Q.act(state, epsilon)

        next_state, reward, done, _ = env.step(action)
        replay_buffer.push(state, action, reward, next_state, done)

        state = next_state
        episode_reward += reward
        if done:
            state = env.reset()
            all_rewards.append(episode_reward)
            time_history.append(time.time() - start_time)
            episode_reward = 0

        if args.render == 1:
            env.render()

        # Learn only once the buffer holds enough transitions.
        if len(replay_buffer) > replay_initial:
            for nou in range(args.number_of_updates):
                loss = compute_td_loss(model_Q, model_target_Q,
                                       args.batch_size, args.gamma,
                                       replay_buffer, args.N)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses.append(loss.data.cpu().numpy())

                num_param_updates += 1
            # Periodically update the target network by Q network to target Q network
            if num_param_updates % args.target_update_freq == 0:
                model_target_Q.load_state_dict(model_Q.state_dict())

            if args.use_optim_scheduler and scheduler is not None:
                # scheduler.step(mean_reward2)
                scheduler.step()
                new_lr = scheduler.get_last_lr()
                # new_lr = optimizer.param_groups[0]['lr']
                if new_lr != old_lr:
                    learning_rates.append(new_lr)
                    print('NewLearningRate: ', new_lr)
                old_lr = new_lr

        # Progress report while the buffer is still filling.
        if frame_idx % 10000 == 0 and len(replay_buffer) <= replay_initial:
            print("Preparing replay buffer with len -- ", len(replay_buffer),
                  "Frame:", frame_idx, "Total time so far:",
                  (time.time() - start_time_frame))

        # Periodic statistics and milestone checkpoints during learning.
        if frame_idx % 10000 == 0 and len(replay_buffer) > replay_initial:
            mean_reward = np.mean(all_rewards[-10:])
            mean_reward2 = np.mean(all_rewards[-100:])
            best_mean_reward = max(best_mean_reward, mean_reward)
            best_mean_reward2 = max(best_mean_reward2, mean_reward2)
            print("Frame:", frame_idx, "Loss:", np.mean(losses),
                  "Total Rewards:",
                  all_rewards[-1], "Average Rewards over all frames:",
                  np.mean(all_rewards), "Last-10 average reward:", mean_reward,
                  "Best mean reward of last-10:", best_mean_reward,
                  "Last-100 average reward:", mean_reward2,
                  "Best mean reward of last-100:", best_mean_reward2, "Time:",
                  time_history[-1], "Total time so far:",
                  (time.time() - start_time_frame))
            if mean_reward >= 18.0:
                if mean_reward > best_18_reward:
                    best_18_reward = mean_reward
                    torch.save(model_Q.state_dict(), args.save_interim_path + \
                              'fmodel_best_18_lr%s_frame_%s_framestack_%s_scheduler_%s_%s.pth'\
                               %(args.lr,frame_idx, args.frame_stack, args.use_optim_scheduler, args.interim_fn))
            if mean_reward >= 19.0:
                if mean_reward > best_19_reward:
                    best_19_reward = mean_reward
                    torch.save(model_Q.state_dict(), args.save_interim_path + \
                              'fmodel_best_19_lr%s_frame_%s_framestack_%s_scheduler_%s_%s.pth'\
                               %(args.lr,frame_idx, args.frame_stack, args.use_optim_scheduler, args.interim_fn))
            if mean_reward >= 20.0:
                if mean_reward > best_20_reward:
                    best_20_reward = mean_reward
                    torch.save(model_Q.state_dict(), args.save_interim_path + \
                              'fmodel_best_20_lr%s_frame_%s_framestack_%s_scheduler_%s_%s.pth'\
                               %(args.lr,frame_idx, args.frame_stack, args.use_optim_scheduler, args.interim_fn))
            if mean_reward >= 21.0:
                if mean_reward > best_21_reward:
                    best_21_reward = mean_reward
                    torch.save(model_Q.state_dict(), args.save_interim_path + \
                              'fmodel_best_21_lr%s_frame_%s_framestack_%s_scheduler_%s_%s.pth'\
                               %(args.lr,frame_idx, args.frame_stack, args.use_optim_scheduler, args.interim_fn))

        # Regular model/result snapshots.
        if frame_idx % args.save_freq_frame == 0:
            results = [losses, all_rewards, time_history]
            torch.save(model_Q.state_dict(), args.save_model_path)
            np.save(args.save_result_path, results)
        if frame_idx == 10000:
            results = [losses, all_rewards, time_history]
            torch.save(model_Q.state_dict(), args.save_interim_path + \
                      'fmodel_lr%s_frame_%s_framestack_%s_scheduler_%s_%s.pth'\
                       %(args.lr,frame_idx, args.frame_stack, args.use_optim_scheduler, args.interim_fn))
            np.save(args.save_interim_path + \
                   'fresults_lr%s_frame_%s_framestack_%s_scheduler_%s_%s.npy' \
                    %(args.lr, frame_idx, args.frame_stack, args.use_optim_scheduler, args.interim_fn), \
                    results)

        if frame_idx % 500000 == 0:
            results = [losses, all_rewards, time_history]
            torch.save(model_Q.state_dict(), args.save_interim_path + \
                      'fmodel_lr%s_frame_%s_framestack_%s_scheduler_%s_%s.pth' \
                      %(args.lr,frame_idx, args.frame_stack, args.use_optim_scheduler, args.interim_fn))
            np.save(args.save_interim_path + \
                   'fresults_lr%s_frame_%s_framestack_%s_scheduler_%s_%s.npy' \
                   %(args.lr,frame_idx, args.frame_stack, args.use_optim_scheduler, args.interim_fn), \
                    results)
示例#13
0
def main(args):
    """Entry point: prepare directories, data and model from *args*, then
    run the full train / validate / test cycle and write all summaries."""
    # Deterministic behaviour across runs.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Optionally wipe a previous run, then (re)create the output tree.
    if args.clear_output_dir and args.output_dir.exists():
        rmtree(args.output_dir)
    args.output_dir.mkdir(parents=True, exist_ok=True)
    args.checkpoint_dir = args.output_dir / 'checkpoints'
    args.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Device choice is recorded on args so downstream helpers can see it.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.device = torch.device("cuda" if args.cuda else "cpu")

    loader_train, loader_val, loader_test = get_loaders(args)

    writer = Summary(args)

    grad_scaler = GradScaler(enabled=args.mixed_precision)
    # Raw logits are produced only for BCE-style losses (never for identity).
    args.output_logits = (args.loss in ['bce', 'binarycrossentropy']
                          and args.model != 'identity')

    net = get_model(args, writer)
    if args.weights_dir is not None:
        net = utils.load_weights(args, net)
    adam_opt = optim.Adam(net.parameters(), lr=args.lr)
    lr_sched = StepLR(adam_opt,
                      step_size=5,
                      gamma=args.gamma,
                      verbose=args.verbose == 2)
    criterion = utils.get_loss_function(name=args.loss)

    if args.critic is None:
        critic_net = None
    else:
        critic_net = Critic(args, summary=writer)

    utils.save_args(args)

    args.global_step = 0
    for epoch in range(args.epochs):
        print(f'Epoch {epoch + 1:03d}/{args.epochs:03d}')
        t_start = time()
        tr_res = train(args,
                       model=net,
                       data=loader_train,
                       optimizer=adam_opt,
                       loss_function=criterion,
                       scaler=grad_scaler,
                       summary=writer,
                       epoch=epoch,
                       critic=critic_net)
        va_res = validate(args,
                          model=net,
                          data=loader_val,
                          loss_function=criterion,
                          summary=writer,
                          epoch=epoch,
                          critic=critic_net)
        t_end = time()

        lr_sched.step()

        # Per-epoch bookkeeping: wall time, current LR and AMP loss scale.
        writer.scalar('elapse', t_end - t_start, step=epoch, mode=0)
        writer.scalar('lr', lr_sched.get_last_lr()[0], step=epoch, mode=0)
        writer.scalar('gradient_scale',
                      grad_scaler.get_scale(),
                      step=epoch,
                      mode=0)

        print(f'Train\t\tLoss: {tr_res["Loss"]:.04f}\n'
              f'Validation\tLoss: {va_res["Loss"]:.04f}\t'
              f'MAE: {va_res["MAE"]:.04f}\t'
              f'PSNR: {va_res["PSNR"]:.02f}\t'
              f'SSIM: {va_res["SSIM"]:.04f}\n')

    utils.save_model(args, net)

    # Final evaluation on the held-out test split.
    test(args,
         model=net,
         data=loader_test,
         loss_function=criterion,
         summary=writer,
         epoch=args.epochs,
         critic=critic_net)

    writer.close()
示例#14
0
        for fold in ["train", "val"]:
            print("*** Epoch {} - {} ***".format(epoch + 1, fold))
            if fold == "train":
                for i, (inputs, targets) in tqdm(enumerate(train_loader),
                                                 total=len(train_loader)):
                    inputs, targets = inputs.to(device), targets.to(device)
                    # if i == 0:
                    #     visualize_batch(inputs, targets, epoch)
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    if scheduler is not None:
                        current_lr = scheduler.get_last_lr()[0]
                    else:
                        current_lr = params.lr

                    log_stats(train_stats, outputs, targets, loss, current_lr)
            elif fold == "val":
                with torch.no_grad():
                    for i, (inputs, targets) in tqdm(enumerate(val_loader),
                                                     total=len(val_loader)):
                        inputs, targets = inputs.to(device), targets.to(device)
                        outputs = model(inputs)
                        loss = criterion(outputs, targets)
                        log_stats(val_stats, outputs, targets, loss)

        # Progress optimizer scheduler
        if scheduler is not None:
示例#15
0
            model.eval()

        for inputs, labels in tqdm(dataloader[phase], disable=True):
            inputs = inputs.to(device)
            labels = labels.long().squeeze().to(device)

            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs).squeeze()
                loss = criterion(outputs, labels)

                if phase == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            preds = torch.max(softmax(outputs), 1)[1]
            y_trues = np.append(y_trues, labels.data.cpu().numpy())
            y_preds = np.append(y_preds, preds.cpu())

        # if phase == 'train':
        #     scheduler.step()

        epoch_loss = running_loss / dataset_sizes[phase]

        print("[{}] Epoch: {}/{} Loss: {} LR: {}".format(
            phase, epoch + 1, num_epochs, epoch_loss, scheduler.get_last_lr()),
              flush=True)
        print('\nconfusion matrix\n' + str(confusion_matrix(y_trues, y_preds)))
        print('\naccuracy\t' + str(accuracy_score(y_trues, y_preds)))
class Runner(object):
    """Jointly trains a prototypical network (few-shot classifier) and an
    IC model on Mini-ImageNet episodes, with pseudo-labels produced by
    ProduceClass driving the IC cross-entropy loss.

    NOTE(review): the exact semantics of ProduceClass / RunnerTool /
    Config / TestTool are defined elsewhere in this file — confirm there.
    """

    def __init__(self):
        # Best few-shot validation accuracy seen so far; gates checkpointing in train().
        self.best_accuracy = 0.0

        # all data
        self.data_train = MiniImageNetDataset.get_data_all(Config.data_root)
        self.task_train = MiniImageNetDataset(self.data_train, Config.num_way, Config.num_shot)
        self.task_train_loader = DataLoader(self.task_train, Config.batch_size, True, num_workers=Config.num_workers)

        # IC: pseudo-label producer/tracker for the IC classification branch.
        self.produce_class = ProduceClass(len(self.data_train), Config.ic_out_dim, Config.ic_ratio)
        self.produce_class.init()

        # model
        self.proto_net = RunnerTool.to_cuda(Config.proto_net)
        self.ic_model = RunnerTool.to_cuda(Config.ic_proto_net)

        RunnerTool.to_cuda(self.proto_net.apply(RunnerTool.weights_init))
        RunnerTool.to_cuda(self.ic_model.apply(RunnerTool.weights_init))

        # optim: one Adam optimizer per network; each StepLR halves the LR
        # every train_epoch // 3 epochs (so twice over a full run).
        self.proto_net_optim = torch.optim.Adam(self.proto_net.parameters(), lr=Config.learning_rate)
        self.ic_model_optim = torch.optim.Adam(self.ic_model.parameters(), lr=Config.learning_rate)

        self.proto_net_scheduler = StepLR(self.proto_net_optim, Config.train_epoch // 3, gamma=0.5)
        self.ic_model_scheduler = StepLR(self.ic_model_optim, Config.train_epoch // 3, gamma=0.5)

        # loss
        self.ic_loss = RunnerTool.to_cuda(nn.CrossEntropyLoss())

        # Eval: few-shot accuracy (uses self.proto_test as the scorer) and IC-head validation.
        self.test_tool_fsl = TestTool(self.proto_test, data_root=Config.data_root,
                                      num_way=Config.num_way, num_shot=Config.num_shot,
                                      episode_size=Config.episode_size, test_episode=Config.test_episode,
                                      transform=self.task_train.transform_test)
        self.test_tool_ic = ICTestTool(feature_encoder=self.proto_net, ic_model=self.ic_model,
                                       data_root=Config.data_root, batch_size=Config.batch_size,
                                       num_workers=Config.num_workers, ic_out_dim=Config.ic_out_dim)
        pass

    def load_model(self):
        """Restore proto_net / ic_model weights from disk when the checkpoint files exist."""
        if os.path.exists(Config.pn_dir):
            self.proto_net.load_state_dict(torch.load(Config.pn_dir))
            Tools.print("load feature encoder success from {}".format(Config.pn_dir))

        if os.path.exists(Config.ic_dir):
            self.ic_model.load_state_dict(torch.load(Config.ic_dir))
            Tools.print("load ic model success from {}".format(Config.ic_dir))
        pass

    def proto(self, task_data):
        """Prototypical forward pass for a batch of training episodes.

        task_data: 5-D tensor (batch, num_images, channels, width, height);
        per episode, the first num_shot * num_way images are the support
        set and the remainder the query set (split below).

        Returns (log_p_y, query_features): log-softmax over negative
        squared distances to the class prototypes, plus the raw query
        embeddings (fed to the IC model by the caller).
        """
        # NOTE(review): 'data_weight' is presumably the image height (typo for 'height').
        data_batch_size, data_image_num, data_num_channel, data_width, data_weight = task_data.shape
        data_x = task_data.view(-1, data_num_channel, data_width, data_weight)
        net_out = self.proto_net(data_x)
        z = net_out.view(data_batch_size, data_image_num, -1)

        # Split embeddings into support / query along the image dimension.
        z_support, z_query = z.split(Config.num_shot * Config.num_way, dim=1)
        z_batch_size, z_num, z_dim = z_support.shape
        z_support = z_support.view(z_batch_size, Config.num_way, Config.num_shot, z_dim)

        # Class prototype = mean of its num_shot support embeddings.
        z_support_proto = z_support.mean(2)
        z_query_expand = z_query.expand(z_batch_size, Config.num_way, z_dim)

        # Squared Euclidean distance to each prototype -> class log-probabilities.
        dists = torch.pow(z_query_expand - z_support_proto, 2).sum(2)
        log_p_y = F.log_softmax(-dists, dim=1)
        return log_p_y, z_query.squeeze(1)

    def proto_test(self, samples, batches):
        """Prototypical scoring used at evaluation time.

        samples: support images (num_way * num_shot of them);
        batches:  query images to classify.
        Returns log-softmax class scores per query image.
        """
        batch_num, _, _, _ = batches.shape

        sample_z = self.proto_net(samples)  # 5x64*5*5
        batch_z = self.proto_net(batches)  # 75x64*5*5
        sample_z = sample_z.view(Config.num_way, Config.num_shot, -1)
        batch_z = batch_z.view(batch_num, -1)
        _, z_dim = batch_z.shape

        # Prototype per class, then broadcast both sides to (batch_num, num_way, z_dim).
        z_proto = sample_z.mean(1)
        z_proto_expand = z_proto.unsqueeze(0).expand(batch_num, Config.num_way, z_dim)
        z_query_expand = batch_z.unsqueeze(1).expand(batch_num, Config.num_way, z_dim)

        dists = torch.pow(z_query_expand - z_proto_expand, 2).sum(2)
        log_p_y = F.log_softmax(-dists, dim=1)
        return log_p_y

    def train(self):
        """Main training loop: per epoch, optimize the joint FSL + IC loss,
        refresh the IC pseudo-labels on the fly, step both LR schedulers,
        and periodically validate / checkpoint the best model.
        """
        Tools.print()
        Tools.print("Training...")

        # Init Update
        # try:
        #     self.proto_net.eval()
        #     self.ic_model.eval()
        #     Tools.print("Init label {} .......")
        #     self.produce_class.reset()
        #     with torch.no_grad():
        #         for task_data, task_labels, task_index in tqdm(self.task_train_loader):
        #             ic_labels = RunnerTool.to_cuda(task_index[:, -1])
        #             task_data, task_labels = RunnerTool.to_cuda(task_data), RunnerTool.to_cuda(task_labels)
        #             log_p_y, query_features = self.proto(task_data)
        #             ic_out_logits, ic_out_l2norm = self.ic_model(query_features)
        #             self.produce_class.cal_label(ic_out_l2norm, ic_labels)
        #             pass
        #         pass
        #     Tools.print("Epoch: {}/{}".format(self.produce_class.count, self.produce_class.count_2))
        # finally:
        #     pass

        for epoch in range(Config.train_epoch):
            self.proto_net.train()
            self.ic_model.train()

            Tools.print()
            self.produce_class.reset()
            all_loss, all_loss_fsl, all_loss_ic = 0.0, 0.0, 0.0
            for task_data, task_labels, task_index in tqdm(self.task_train_loader):
                # Last column of task_index identifies each sample for pseudo-labeling.
                ic_labels = RunnerTool.to_cuda(task_index[:, -1])
                task_data, task_labels = RunnerTool.to_cuda(task_data), RunnerTool.to_cuda(task_labels)

                ###########################################################################
                # 1 calculate features
                log_p_y, query_features = self.proto(task_data)
                ic_out_logits, ic_out_l2norm = self.ic_model(query_features)

                # 2: read the current pseudo-labels, then update them with this
                #    batch's embeddings (targets are fetched BEFORE the update).
                ic_targets = self.produce_class.get_label(ic_labels)
                self.produce_class.cal_label(ic_out_l2norm, ic_labels)

                # 3 loss: weighted sum of the few-shot NLL and the IC cross-entropy.
                loss_fsl = -(log_p_y * task_labels).sum() / task_labels.sum() * Config.loss_fsl_ratio
                loss_ic = self.ic_loss(ic_out_logits, ic_targets) * Config.loss_ic_ratio
                loss = loss_fsl + loss_ic
                all_loss += loss.item()
                all_loss_fsl += loss_fsl.item()
                all_loss_ic += loss_ic.item()

                # 4 backward: one joint backward pass, then step both optimizers.
                self.proto_net.zero_grad()
                self.ic_model.zero_grad()
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(self.proto_net.parameters(), 0.5)
                # torch.nn.utils.clip_grad_norm_(self.ic_model.parameters(), 0.5)
                self.proto_net_optim.step()
                self.ic_model_optim.step()
                ###########################################################################
                pass

            ###########################################################################
            # print
            Tools.print("{:6} loss:{:.3f} fsl:{:.3f} ic:{:.3f} lr:{}".format(
                epoch + 1, all_loss / len(self.task_train_loader), all_loss_fsl / len(self.task_train_loader),
                all_loss_ic / len(self.task_train_loader), self.proto_net_scheduler.get_last_lr()))
            Tools.print("Train: [{}] {}/{}".format(epoch, self.produce_class.count, self.produce_class.count_2))
            self.proto_net_scheduler.step()
            self.ic_model_scheduler.step()
            ###########################################################################

            ###########################################################################
            # Val: every val_freq epochs, evaluate and checkpoint on improvement.
            if epoch % Config.val_freq == 0:
                self.proto_net.eval()
                self.ic_model.eval()

                self.test_tool_ic.val(epoch=epoch)
                val_accuracy = self.test_tool_fsl.val(episode=epoch, is_print=True)

                if val_accuracy > self.best_accuracy:
                    self.best_accuracy = val_accuracy
                    torch.save(self.proto_net.state_dict(), Config.pn_dir)
                    torch.save(self.ic_model.state_dict(), Config.ic_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            ###########################################################################
            pass

        pass

    pass
示例#17
0
def main():
    """Train a ResNet classifier on CIFAR-10, logging and checkpointing each epoch.

    Relies on module-level names defined elsewhere in this file:
    ``args``, ``device``, ``train``, ``validate``, ``init_logfile``,
    ``log`` and ``get_num_classes``.
    """
    expdir = os.path.join("exp", "train_" + args.tag)
    model_dir = os.path.join(expdir, "models")
    log_dir = os.path.join(expdir, "log")
    # args.model_dir = model_dir
    # Create the experiment directory tree; exist_ok makes re-runs idempotent.
    for d in ("exp", expdir, model_dir, log_dir):
        os.makedirs(d, exist_ok=True)

    logfilename = os.path.join(log_dir, "log.txt")
    init_logfile(
        logfilename,
        "arch={} epochs={} batch={} lr={} lr_step={} gamma={} noise_sd={} k_value={} eps_step={}".format(
        args.arch, args.epochs, args.batch, args.lr, args.lr_step_size,
        args.gamma, args.noise_sd, args.k_value, args.eps_step))
    log(logfilename, "epoch\ttime\tlr\ttrain loss\ttrain acc\tval loss\tval acc")

    cifar_train = datasets.CIFAR10("./dataset_cache", train=True, download=True, transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ]))
    cifar_val = datasets.CIFAR10("./dataset_cache", train=False, download=True, transform=transforms.ToTensor())

    train_loader = DataLoader(cifar_train, shuffle=True, batch_size=args.batch, num_workers=args.workers)
    val_loader = DataLoader(cifar_val, shuffle=False, batch_size=args.batch, num_workers=args.workers)

    # Dispatch table instead of a repeated if/elif chain; unknown arch names
    # fall back to resnet18, matching the original behavior.
    arch_factories = {
        "resnet18": torchvision.models.resnet18,
        "resnet34": torchvision.models.resnet34,
        "resnet50": torchvision.models.resnet50,
    }
    factory = arch_factories.get(args.arch, torchvision.models.resnet18)
    model = factory(pretrained=False, progress=True, num_classes=get_num_classes()).to(device)

    criterion = CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma, verbose=True)

    for i in range(args.epochs):
        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion, optimizer, scheduler, i)
        val_loss, val_acc = validate(val_loader, model, criterion)
        after = time.time()

        log(logfilename, "{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
            i, float(after - before),
            float(scheduler.get_last_lr()[0]), train_loss, train_acc, val_loss, val_acc))

        # Checkpoint every epoch (model + optimizer state, so training can resume).
        torch.save(
            {
                'epoch': i,
                'arch': args.arch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, os.path.join(model_dir, "ep{}.pth".format(i)))
示例#18
0
        init_params)  #Same parameter initialisation as for early stopping
    optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    updates_counter = 0
    epoch = 1
    while updates_counter < opt_upd:
        updates_counter = P7.early_stopping_retrain(args, model, device,
                                                    train_loader, optimizer,
                                                    epoch, opt_upd,
                                                    updates_counter, scheduler,
                                                    upd_epoch)
        print('')
        epoch += 1
    print('Learning rate when ending training: {:.4g}\n'.format(
        scheduler.get_last_lr()[0]))

    ### Test time ###
    if use_cuda is True:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        ACC, y_hat = P7.test(model, device, test_loader)
        end.record()
        # Waits for everything to finish running
        torch.cuda.synchronize()
        test_time = start.elapsed_time(end) / 1000  #GPU time in seconds
        print('{:.4g} seconds'.format(test_time))
    else:
        time1 = time()
        ACC, y_hat = P7.test(model, device, test_loader)
示例#19
0
class VAE(object):
    """Beta-VAE trainer wrapper: builds the model/optimizer/scheduler,
    performs single-batch gradient updates, and handles checkpointing
    and diagnostic plotting.
    """

    def __init__(self, x_dim, config, logger):
        self.device = torch.device(config.device)

        self.model = simple_VAE_model(x_dim=x_dim,
                                      hidden_dims=config.hidden_dims,
                                      z_dim=config.z_dim,
                                      logger=logger).to(device=self.device)
        self.optimizer = Adam(self.model.parameters(), lr=config.lr, eps=1e-5)
        # step_size=1: the LR decays by gamma after every single update.
        self.scheduler = StepLR(self.optimizer,
                                step_size=1,
                                gamma=config.lr_decay_per_update)

        # Weight of the KL (prior) term in the ELBO.
        self.beta = config.beta

        self.metrics_to_record = {
            'update_i', 'eval_update_i', 'eval_epoch_i', 'epoch_i',
            'reconstruction_loss', 'prior_loss', 'total_loss',
            'wallclock_time', 'epoch_time', 'lr'
        }

    def update_parameters(self, batch_data, update_i):
        """Run one gradient step on a batch and return scalar recordings.

        batch_data: array-like batch of flat samples, shape (batch, x_dim).
        """
        x_batch = torch.FloatTensor(batch_data).to(self.device)

        # BUGFIX: encode the device tensor, not the raw `batch_data` —
        # previously the un-converted input bypassed the .to(device) above.
        z_batch, mu_batch, std_batch = self.model.encode(x=x_batch)
        x_hat_batch = self.model.decode(z_batch)

        # Reconstruction loss (MSE between original and reconstructed sample)
        reconstruction_loss = F.mse_loss(
            input=x_hat_batch, target=x_batch,
            reduction='none').sum(dim=1).mean(dim=0)

        # Prior loss (KL-divergence between posterior and prior distributions)
        # see: https://mr-easy.github.io/2020-04-16-kl-divergence-between-2-gaussian-distributions
        prior_loss = 0.5 * (
            (mu_batch.square().sum(dim=1) + std_batch.square().sum(dim=1) -
             z_batch.shape[1] -
             torch.log(std_batch.square().prod(1)))).mean(dim=0)

        # Taking a gradient step
        total_loss = reconstruction_loss + self.beta * prior_loss

        self.model.zero_grad()
        total_loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        # Bookkeeping

        new_recordings = {
            "update_i": update_i,
            "reconstruction_loss": reconstruction_loss.item(),
            "prior_loss": prior_loss.item(),
            "total_loss": total_loss.item(),
            # [0] unwraps get_last_lr()'s list so 'lr' is a scalar like
            # every other recorded metric (and plots as a clean curve).
            "lr": self.scheduler.get_last_lr()[0]
        }
        return new_recordings

    # Save model parameters
    def save_model(self, path, logger):
        """Move the model to CPU, save its state dict, then move it back."""
        self.to(torch.device("cpu"))
        logger.info(f'Saving models to {path}')
        torch.save(self.model.state_dict(), str(path / "vae_model.pt"))
        self.to(torch.device(self.device))

    # Load model parameters
    def load_model(self, path, logger):
        # BUGFIX: the message previously repeated {path} twice.
        logger.info(f'Loading models from {path}')
        self.model.load_state_dict(torch.load(path / "vae_model.pt"))
        self.to(self.device)

    # Send models to different device
    def to(self, device):
        self.model.to(device)

    # Pytorch modules that wandb can monitor
    def wandb_watchable(self):
        return [self.model]

    # Graphs
    def create_plots(self, train_recorder, save_dir):
        """Plot the recorded loss/lr curves into save_dir/graphs.png."""
        fig, axes = create_fig((3, 3))
        plot_curves(
            axes[0, 0],
            xs=[remove_nones(train_recorder.tape['update_i'])],
            ys=[remove_nones(train_recorder.tape['reconstruction_loss'])],
            xlabel='update_i',
            ylabel='reconstruction_loss')
        plot_curves(axes[0, 1],
                    xs=[remove_nones(train_recorder.tape['update_i'])],
                    ys=[remove_nones(train_recorder.tape['prior_loss'])],
                    xlabel='update_i',
                    ylabel='prior_loss')
        plot_curves(axes[0, 2],
                    xs=[remove_nones(train_recorder.tape['update_i'])],
                    ys=[remove_nones(train_recorder.tape['total_loss'])],
                    xlabel='update_i',
                    ylabel='total_loss')
        plot_curves(axes[1, 0],
                    xs=[remove_nones(train_recorder.tape['update_i'])],
                    ys=[remove_nones(train_recorder.tape['lr'])],
                    xlabel='update_i',
                    ylabel='lr')

        plt.tight_layout()

        fig.savefig(str(save_dir / 'graphs.png'))
        plt.close(fig)
示例#20
0
        if reward is not None:  # the episode is done
            ep_time = time.time() - ts
            episode += 1
            total_rewards.append(reward)
            total_losses.append(ep_loss)
            speed = (frame_idx - ts_frame) / ep_time
            ts_frame = frame_idx
            ts = time.time()
            ep_loss = 0
            mean_reward = np.mean(total_rewards[-100:])
            mean_return.append(mean_reward)
            print(
                "%d: done %d games, mean reward %.2f, reward %.2f, lr %.5f, loss %.2f, time %.2f s, eps %.2f, speed %.2f f/s"
                % (frame_idx, episode, mean_reward, reward,
                   scheduler.get_last_lr()[0], total_losses[-1], ep_time,
                   epsilon, speed))
            if best_mean_reward < mean_reward:
                if episode > SAVE_EP:
                    torch.save(net.state_dict(),
                               '{}/episode_{}'.format(out_dir, episode))
                print("mean reward %.2f -> %.2f" %
                      (best_mean_reward, mean_reward))
                best_mean_reward = mean_reward
            if episode == MAX_EP:
                print("Finished in %d frames!" % frame_idx)
                break

        if len(buffer) >= REPLAY_START_SIZE and frame_idx % UPDATE_FRE == 0:
            optimizer.zero_grad()
            batch = buffer.sample(BATCH_SIZE)
示例#21
0
class TD3(object):
    """Twin Delayed DDPG (TD3) agent: actor/critic pairs with target
    networks, delayed policy updates and polyak averaging.

    State is a tuple whose first element is a cropped image and whose
    remaining elements are scalar features — both actor and critics take
    (image, values) as separate inputs.
    """

    # TD3 Constructor
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)    # Instance of Actor Model
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device) # Instance of Actor Target
        self.actor_target.load_state_dict(self.actor.state_dict())  # Copy weights of Actor Model into Actor Target
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),lr=0.0001)  # Optimizer for training Actor network
        self.actor_lr_scheduler = StepLR(self.actor_optimizer, step_size=10000, gamma=0.9)  # Use StepLR to adjust Learning rate during training
        self.critic = Critic(state_dim, action_dim).to(device)  # Instance of Model Critics
        self.critic_target = Critic(state_dim, action_dim).to(device)   # Instance of Target Critics
        self.critic_target.load_state_dict(self.critic.state_dict())    # Copy weights of Model Critics into Target Citics
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),lr=0.0001)    # Optimizer for training Critic networks
        self.critic_lr_scheduler = StepLR(self.critic_optimizer, step_size=10000, gamma=0.9)    # Use StepLR to adjust Learning rate during training
        self.max_action = max_action    # Maximum value of action

    # Define a method to estimate an action for given state
    def select_action(self, state):
        """Return the deterministic action for one environment state.

        state[0] is the cropped image; state[1:] are scalar features.
        Returns a flattened numpy array (exploration noise, if any, is
        added by the caller).
        """
        # first state element is cropped image
        stateImg = np.expand_dims(state[0],0)
        # Convert stateImg to tensor
        stateImg = torch.Tensor(stateImg).to(device)
        # Rest of the elements are float values. Create a float32 numpy array for it
        stateValues = np.array(state[1:], dtype=np.float32)
        # Convert StateValues to tensor with 1 Row 
        stateValues = torch.Tensor(stateValues.reshape(1, -1)).to(device)
        # Set model mode to 'evaluation' so batchNorm, DropOuts must be adjusted accordingly
        self.actor.eval()
        # Pass state through Actor Model forward pass to predict an action.
        # predicted action from tensor to numpy array before returning
        return(self.actor(stateImg,stateValues).cpu().data.numpy().flatten())
    

    def train(self, replay_buffer, iterations, batch_size=100, discount=0.99, tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_freq=2):
        """Run `iterations` TD3 update steps from the replay buffer.

        Critics are updated every step; the actor and both target
        networks only every `policy_freq` steps (the 'delayed' part of
        TD3).  Prints the average critic/actor losses and current LRs.
        """
        # Keep track of Critic and Actor Loss values
        criticLossAvg = 0.0
        actorLossAvg = 0.0
        for it in range(iterations):
            
            # Step 4: We sample a batch of transitions (s, s’, a, r) from the memory
            batch_stateImgs,batch_stateValues, batch_next_stateImgs,batch_next_stateValues,\
                batch_actions, batch_rewards, batch_dones = replay_buffer.sample(batch_size)
      
            stateImg = torch.Tensor(batch_stateImgs).to(device)
            stateValues = torch.Tensor(batch_stateValues).to(device)
            next_stateImgs = torch.Tensor(batch_next_stateImgs).to(device)
            next_stateValues = torch.Tensor(batch_next_stateValues).to(device)
            action = torch.Tensor(batch_actions).to(device)
            reward = torch.Tensor(batch_rewards).to(device)
            done = torch.Tensor(batch_dones).to(device)
            
            # Step 5: From the next state s’, the Actor target plays the next action a’
            next_action = self.actor_target(next_stateImgs,next_stateValues)
            
            # Step 6: We add Gaussian noise to this next action a’ and we clamp it in a range of values supported by the environment
            noise = torch.Tensor(batch_actions).data.normal_(0, policy_noise).to(device)
            noise = noise.clamp(-noise_clip, noise_clip)
            next_action = (next_action + noise).clamp(-self.max_action, self.max_action)
            
            # Step 7: The two Critic targets take each the couple (s’, a’) as input and return two Q-values Qt1(s’,a’) and Qt2(s’,a’) as outputs
            target_Q1, target_Q2 = self.critic_target(next_stateImgs,next_stateValues, next_action)
            
            # Step 8: We keep the minimum of these two Q-values: min(Qt1, Qt2)
            target_Q = torch.min(target_Q1, target_Q2)
            
            # Step 9: We get the final target of the two Critic models, which is: Qt = r + γ * min(Qt1, Qt2), where γ is the discount factor
            # (detach() stops gradients from flowing into the target networks)
            target_Q = reward + ((1 - done) * discount * target_Q).detach()
            
            # Step 10: The two Critic models take each the couple (s, a) as input and return two Q-values Q1(s,a) and Q2(s,a) as outputs
            current_Q1, current_Q2 = self.critic(stateImg,stateValues,action)
            
            # Step 11: We compute the loss coming from the two Critic models: Critic Loss = MSE_Loss(Q1(s,a), Qt) + MSE_Loss(Q2(s,a), Qt)
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
            #print(critic_loss.item(),type(critic_loss.item()))
        
            criticLossAvg += critic_loss.item()
            
            # Step 12: We backpropagate this Critic loss and update the parameters of the two Critic models with a SGD optimizer
            self.critic.train() # Set Model mode to 'training'
            self.critic_optimizer.zero_grad()   # reset grad values
            critic_loss.backward()              # update grad values
            self.critic_optimizer.step()        # update model parameters
            self.critic_lr_scheduler.step()     # step through critic LR scheduler
            
            # Step 13: Once every two iterations, we update our Actor model by performing gradient ascent on the output of the first Critic model
            if it % policy_freq == 0:
                actor_loss = -self.critic.Q1(stateImg,stateValues, self.actor(stateImg,stateValues)).mean()
                actorLossAvg += actor_loss.item()
                self.actor.train()  # Set Model mode to 'training'
                self.actor_optimizer.zero_grad()    # Reset grad values
                actor_loss.backward()               # update grad values
                self.actor_optimizer.step()         # update model parameters
                self.actor_lr_scheduler.step()      # step through actor LR scheduler
                
                # Step 14: Still once every two iterations, we update the weights of the Actor target by polyak averaging
                for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                  target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
                
                # Step 15: Still once every two iterations, we update the weights of the Critic target by polyak averaging
                for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                  target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
                  
        criticLossAvg /= iterations 
        actorLossAvg /= iterations       
        # NOTE(review): actorLossAvg is divided by `iterations` although the
        # actor only updates every policy_freq steps — confirm intended.
        print('Avg CriticLoss: ',criticLossAvg,' Avg ActorLoss ',actorLossAvg, \
              ' ActorLR: ',self.actor_lr_scheduler.get_last_lr(),' CriticLR: ',self.critic_lr_scheduler.get_last_lr())
  
    def save(self, filename, directory):
          """Persist actor and critic weights to `directory` under `filename`."""
          torch.save(self.actor.state_dict(), '%s/%s_actor.pth' % (directory, filename))
          torch.save(self.critic.state_dict(), '%s/%s_critic.pth' % (directory, filename))
  
    def load(self, filename, directory):
          """Load actor and critic weights saved by save(); maps to CPU."""
          self.actor.load_state_dict(torch.load('%s/%s_actor.pth' % (directory, filename),map_location=torch.device('cpu')))
          self.critic.load_state_dict(torch.load('%s/%s_critic.pth' % (directory, filename),map_location=torch.device('cpu')))
          
          
#policy = TD3(4, 1, 5.0)          
示例#22
0
                    'epoch %2d, iter %4d: loss: %.4f (%.4f) | load time: %.2f | back time: %.2f'
                    % (epoch + 1, iter + 1, loss_confmap.item(), loss_ave,
                       load_time, back_time))
                with open(train_log_name, 'a+') as f_log:
                    f_log.write(
                        'epoch %2d, iter %4d: loss: %.4f (%.4f) | load time: %.2f | back time: %.2f\n'
                        % (epoch + 1, iter + 1, loss_confmap.item(), loss_ave,
                           load_time, back_time))

                writer.add_scalar('loss/loss_all', loss_confmap.item(),
                                  iter_count)
                writer.add_scalar('loss/loss_ave', loss_ave, iter_count)
                writer.add_scalar('time/time_load', load_time, iter_count)
                writer.add_scalar('time/time_back', back_time, iter_count)
                writer.add_scalar('param/param_lr',
                                  scheduler.get_last_lr()[0], iter_count)

                if stacked_num is not None:
                    confmap_pred = confmap_preds[stacked_num -
                                                 1].cpu().detach().numpy()
                else:
                    confmap_pred = confmap_preds.cpu().detach().numpy()

                if 'mnet_cfg' in model_cfg:
                    chirp_amp_curr = chirp_amp(data.numpy()[0, :, 0, 0, :, :],
                                               radar_configs['data_type'])
                else:
                    chirp_amp_curr = chirp_amp(data.numpy()[0, :, 0, :, :],
                                               radar_configs['data_type'])

                # draw train images
示例#23
0
class Solver():
    """Training/validation driver for a 3D visual-grounding model.

    Runs ``epoch`` passes over the training set, evaluates on the validation
    set after every epoch, steps the LR / BN-momentum schedulers, streams
    metrics to tensorboard and a text log, and checkpoints both the last and
    the best model (by ``iou_rate_0.25``).

    Depends on module-level names defined elsewhere in this file:
    ``CONF``, ``get_loss``, ``get_eval``, ``decode_eta``,
    ``BNMomentumScheduler`` and the ``*_REPORT_TEMPLATE`` format strings.
    """

    def __init__(self, model, config, dataloader, optimizer, stamp, val_step=10,
                 lr_decay_step=None, lr_decay_rate=None, bn_decay_step=None, bn_decay_rate=None):
        # Args:
        #   model:       network to train (torch.nn.Module).
        #   config:      dataset/model config forwarded to get_loss/get_eval.
        #   dataloader:  dict with "train" and "val" DataLoaders.
        #   optimizer:   torch optimizer over model parameters.
        #   stamp:       run identifier; output sub-directory under CONF.PATH.OUTPUT.
        #   val_step:    nominal validation interval in epochs (used for ETA math).
        #   lr_decay_step / lr_decay_rate: LR schedule; a list selects MultiStepLR,
        #                a scalar selects StepLR. Both must be truthy to enable.
        #   bn_decay_step / bn_decay_rate: BN-momentum decay schedule parameters.

        self.epoch = 0  # set in __call__
        self.verbose = 0  # set in __call__

        self.model = model
        self.config = config
        self.dataloader = dataloader
        self.optimizer = optimizer
        self.stamp = stamp
        self.val_step = val_step

        self.lr_decay_step = lr_decay_step
        self.lr_decay_rate = lr_decay_rate
        self.bn_decay_step = bn_decay_step
        self.bn_decay_rate = bn_decay_rate

        # Best validation snapshot so far; losses start at +inf, scores at -inf
        # so the first completed validation epoch always becomes the new best.
        self.best = {
            "epoch": 0,
            "loss": float("inf"),
            "ref_loss": float("inf"),
            "lang_loss": float("inf"),
            "lang_acc": -float("inf"),
            "ref_acc": -float("inf"),
            "iou_rate_0.25": -float("inf"),
            "iou_rate_0.5": -float("inf")
        }

        # log
        self.init_log()

        # tensorboard
        os.makedirs(os.path.join(CONF.PATH.OUTPUT, stamp, "tensorboard/train"), exist_ok=True)
        os.makedirs(os.path.join(CONF.PATH.OUTPUT, stamp, "tensorboard/val"), exist_ok=True)
        self._log_writer = {
            "train": SummaryWriter(os.path.join(CONF.PATH.OUTPUT, stamp, "tensorboard/train")),
            "val": SummaryWriter(os.path.join(CONF.PATH.OUTPUT, stamp, "tensorboard/val"))
        }

        # training log
        log_path = os.path.join(CONF.PATH.OUTPUT, stamp, "log.txt")
        self.log_fout = open(log_path, "a")

        # private
        # only for internal access and temporary results
        self._running_log = {}
        self._global_iter_id = 0
        self._total_iter = {}  # set in __call__

        # templates
        # NOTE: double-underscore names are name-mangled to _Solver__*;
        # they are only accessed from inside this class.
        self.__iter_report_template = ITER_REPORT_TEMPLATE
        self.__epoch_report_template = EPOCH_REPORT_TEMPLATE
        self.__best_report_template = BEST_REPORT_TEMPLATE

        # lr scheduler
        if lr_decay_step and lr_decay_rate:
            if isinstance(lr_decay_step, list):
                self.lr_scheduler = MultiStepLR(optimizer, lr_decay_step, lr_decay_rate)
            else:
                self.lr_scheduler = StepLR(optimizer, lr_decay_step, lr_decay_rate)
        else:
            self.lr_scheduler = None

        # bn scheduler
        if bn_decay_step and bn_decay_rate:
            it = -1
            start_epoch = 0
            BN_MOMENTUM_INIT = 0.5
            BN_MOMENTUM_MAX = 0.001
            # Momentum decays geometrically from INIT, clamped at MAX
            # (MAX < INIT, so max() acts as the floor of the decay).
            bn_lbmd = lambda it: max(BN_MOMENTUM_INIT * bn_decay_rate ** (int(it / bn_decay_step)), BN_MOMENTUM_MAX)
            self.bn_scheduler = BNMomentumScheduler(model, bn_lambda=bn_lbmd, last_epoch=start_epoch - 1)
        else:
            self.bn_scheduler = None

    def __call__(self, epoch, verbose):
        """Run the full training loop for ``epoch`` epochs.

        ``verbose`` is the train-report interval in iterations.
        Validation runs after every epoch; KeyboardInterrupt finishes
        cleanly (checkpoint + best report) and exits the process.
        """
        # setting
        self.epoch = epoch
        self.verbose = verbose
        self._total_iter["train"] = len(self.dataloader["train"]) * epoch
        self._total_iter["val"] = len(self.dataloader["val"]) * self.val_step

        # NOTE(review): if epoch == 0 the loop body never runs and the
        # trailing self._finish(epoch_id) raises NameError — confirm callers
        # always pass epoch >= 1.
        for epoch_id in range(epoch):
            try:
                self._log("epoch {} starting...".format(epoch_id + 1))

                # feed
                self._feed(self.dataloader["train"], "train", epoch_id)

                # save model
                self._log("saving last models...\n")
                model_root = os.path.join(CONF.PATH.OUTPUT, self.stamp)
                torch.save(self.model.state_dict(), os.path.join(model_root, "model_last.pth"))

                print("evaluating...")
                # reset accumulated logs so validation stats are per-epoch
                self.init_log()
                # val
                self._feed(self.dataloader["val"], "val", epoch_id)

                # update lr scheduler
                if self.lr_scheduler:
                    self.lr_scheduler.step()
                    self._log("update learning rate --> {}\n".format(self.lr_scheduler.get_last_lr()))

                # update bn scheduler
                if self.bn_scheduler:
                    self.bn_scheduler.step()
                    self._log("update batch normalization momentum --> {}\n".format(
                        self.bn_scheduler.lmbd(self.bn_scheduler.last_epoch)))

            except KeyboardInterrupt:
                # finish training
                self._finish(epoch_id)
                exit()

        # finish training
        self._finish(epoch_id)

    def _log(self, info_str):
        """Write a line to the text log (flushed) and echo it to stdout."""
        self.log_fout.write(info_str + "\n")
        self.log_fout.flush()
        print(info_str)

    def _set_phase(self, phase):
        """Put the model in train or eval mode for the given phase."""
        if phase == "train":
            self.model.train()
        elif phase == "val":
            self.model.eval()
        else:
            raise ValueError("invalid phase")

    def _forward(self, data_dict):
        """Run the model; the model mutates and returns the data dict."""
        data_dict = self.model(data_dict)

        return data_dict

    def _backward(self):
        """Backprop the current running loss and take an optimizer step."""
        # optimize
        self.optimizer.zero_grad()
        self._running_log["loss"].backward()
        # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

    def _compute_loss(self, data_dict):
        """Compute losses via get_loss and stash them in the running log."""
        data_dict = get_loss(
            data_dict=data_dict,
            config=self.config,
        )

        # dump
        self._running_log["ref_loss"] = data_dict["ref_loss"]
        self._running_log["lang_loss"] = data_dict["lang_loss"]
        self._running_log["seg_loss"] = data_dict["seg_loss"]
        self._running_log["loss"] = data_dict["loss"]

    def _eval(self, data_dict):
        """Compute accuracy/IoU metrics via get_eval into the running log."""
        data_dict = get_eval(
            data_dict=data_dict,
            config=self.config,
        )

        # dump
        self._running_log["lang_acc"] = data_dict["lang_acc"].item()
        self._running_log["ref_acc"] = np.mean(data_dict["ref_acc"])
        self._running_log["seg_acc"] = data_dict["seg_acc"].item()
        # ref_iou is kept as a list and concatenated across iterations in _feed
        self._running_log['ref_iou'] = data_dict['ref_iou']

    def _feed(self, dataloader, phase, epoch_id):
        """Iterate one epoch of ``dataloader`` in the given phase.

        For "train": forward + backward + eval per batch, periodic reports
        (every ``self.verbose`` iterations, after which the log window is
        reset). For "val": forward + eval only, then epoch report and
        best-model checkpointing.
        """
        # switch mode
        self._set_phase(phase)

        # change dataloader
        dataloader = dataloader if phase == "train" else tqdm(dataloader)
        fetch_time_start = time.time()

        for data_dict in dataloader:

            # move to cuda
            # only the listed tensor keys are moved; everything else stays on CPU
            for key in data_dict:
                if key in ['lang_feat', 'lang_len', 'object_cat', 'lidar', 'point_min', 'point_max', 'mlm_label',
                           'ref_center_label', 'ref_size_residual_label']:
                    data_dict[key] = data_dict[key].cuda()

            # initialize the running loss
            self._running_log = {
                # loss
                "loss": 0,
                "ref_loss": 0,
                "lang_loss": 0,
                "seg_loss": 0,
                # acc
                "lang_acc": 0,
                "ref_acc": 0,
                "seg_acc": 0,
                "iou_rate_0.25": 0,
                "iou_rate_0.5": 0
            }

            # load
            self.log[phase]["fetch"].append(time.time() - fetch_time_start)

            # debug only
            # with torch.autograd.set_detect_anomaly(True):
            # forward
            start = time.time()
            data_dict = self._forward(data_dict)
            self._compute_loss(data_dict)
            self.log[phase]["forward"].append(time.time() - start)

            # backward
            if phase == "train":
                start = time.time()
                self._backward()
                self.log[phase]["backward"].append(time.time() - start)

            # eval
            start = time.time()
            self._eval(data_dict)
            self.log[phase]["eval"].append(time.time() - start)

            # record log
            self.log[phase]["loss"].append(self._running_log["loss"].item())
            self.log[phase]["ref_loss"].append(self._running_log["ref_loss"].item())
            self.log[phase]["lang_loss"].append(self._running_log["lang_loss"].item())
            self.log[phase]["seg_loss"].append(self._running_log["seg_loss"].item())

            self.log[phase]["lang_acc"].append(self._running_log["lang_acc"])
            self.log[phase]["ref_acc"].append(self._running_log["ref_acc"])
            self.log[phase]["seg_acc"].append(self._running_log["seg_acc"])
            # list concatenation: accumulate every sample's IoU seen so far
            self.log[phase]['ref_iou'] += self._running_log['ref_iou']

            # iou_rate_* are recomputed cumulatively over all IoUs in the
            # current log window (not per batch)
            ious = self.log[phase]['ref_iou']
            self.log[phase]['iou_rate_0.25'] = np.array(ious)[np.array(ious) >= 0.25].shape[0] / np.array(ious).shape[0]
            self.log[phase]['iou_rate_0.5'] = np.array(ious)[np.array(ious) >= 0.5].shape[0] / np.array(ious).shape[0]

            # report
            if phase == "train":
                iter_time = self.log[phase]["fetch"][-1]
                iter_time += self.log[phase]["forward"][-1]
                iter_time += self.log[phase]["backward"][-1]
                iter_time += self.log[phase]["eval"][-1]
                self.log[phase]["iter_time"].append(iter_time)
                if (self._global_iter_id + 1) % self.verbose == 0:
                    self._train_report(epoch_id)
                    # dump log
                    self._dump_log("train")
                    # reset the window so the next report averages fresh stats
                    self.init_log()

                self._global_iter_id += 1
            fetch_time_start = time.time()

        # check best
        if phase == "val":
            ious = self.log[phase]['ref_iou']
            self.log[phase]['iou_rate_0.25'] = np.array(ious)[np.array(ious) >= 0.25].shape[0] / np.array(ious).shape[0]
            self.log[phase]['iou_rate_0.5'] = np.array(ious)[np.array(ious) >= 0.5].shape[0] / np.array(ious).shape[0]

            self._dump_log("val")
            self._epoch_report(epoch_id)

            # best-model selection criterion
            cur_criterion = "iou_rate_0.25"
            cur_best = self.log[phase][cur_criterion]
            if cur_best > self.best[cur_criterion]:
                self._log("best {} achieved: {}".format(cur_criterion, cur_best))
                self.best["epoch"] = epoch_id + 1
                self.best["loss"] = np.mean(self.log[phase]["loss"])
                self.best["ref_loss"] = np.mean(self.log[phase]["ref_loss"])
                self.best["lang_loss"] = np.mean(self.log[phase]["lang_loss"])
                self.best["seg_loss"] = np.mean(self.log[phase]["seg_loss"])
                self.best["lang_acc"] = np.mean(self.log[phase]["lang_acc"])
                self.best["ref_acc"] = np.mean(self.log[phase]["ref_acc"])
                self.best["seg_acc"] = np.mean(self.log[phase]["seg_acc"])
                self.best["iou_rate_0.25"] = self.log[phase]['iou_rate_0.25']
                self.best["iou_rate_0.5"] = self.log[phase]['iou_rate_0.5']

                # save model
                self._log("saving best models...\n")
                model_root = os.path.join(CONF.PATH.OUTPUT, self.stamp)
                torch.save(self.model.state_dict(), os.path.join(model_root, "model.pth"))

    def _dump_log(self, phase):
        """Write mean losses/scores of the current log window to tensorboard."""
        log = {
            "loss": ["loss", "ref_loss", "lang_loss", "seg_loss"],
            "score": ["lang_acc", "ref_acc", "seg_acc"]
        }
        for key in log:
            for item in log[key]:
                self._log_writer[phase].add_scalar(
                    "{}/{}".format(key, item),
                    np.mean([v for v in self.log[phase][item]]),
                    self._global_iter_id
                )

        # iou rates are already scalars (cumulative), logged as-is
        self._log_writer[phase].add_scalar(
            "{}/{}".format("score", 'iou_rate_0.25'),
            self.log[phase]['iou_rate_0.25'],
            self._global_iter_id
        )
        self._log_writer[phase].add_scalar(
            "{}/{}".format("score", 'iou_rate_0.5'),
            self.log[phase]['iou_rate_0.5'],
            self._global_iter_id
        )


    def _finish(self, epoch_id):
        """Report best results, save checkpoint/last model, export scalars."""
        # print best
        self._best_report()

        # save check point
        self._log("saving checkpoint...\n")
        save_dict = {
            "epoch": epoch_id,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict()
        }
        checkpoint_root = os.path.join(CONF.PATH.OUTPUT, self.stamp)
        torch.save(save_dict, os.path.join(checkpoint_root, "checkpoint.tar"))

        # save model
        self._log("saving last models...\n")
        model_root = os.path.join(CONF.PATH.OUTPUT, self.stamp)
        torch.save(self.model.state_dict(), os.path.join(model_root, "model_last.pth"))

        # export
        # NOTE(review): export_scalars_to_json is a tensorboardX API —
        # confirm SummaryWriter here is tensorboardX, not torch.utils.tensorboard
        for phase in ["train", "val"]:
            self._log_writer[phase].export_scalars_to_json(
                os.path.join(CONF.PATH.OUTPUT, self.stamp, "tensorboard/{}".format(phase), "all_scalars.json"))

    def _train_report(self, epoch_id):
        """Format and log the periodic training-iteration report with ETA."""
        # compute ETA
        fetch_time = self.log["train"]["fetch"]
        forward_time = self.log["train"]["forward"]
        backward_time = self.log["train"]["backward"]
        eval_time = self.log["train"]["eval"]
        iter_time = self.log["train"]["iter_time"]

        mean_train_time = np.mean(iter_time)
        # validation iterations only fetch + forward (no backward)
        mean_est_val_time = np.mean([fetch + forward for fetch, forward in zip(fetch_time, forward_time)])
        eta_sec = (self._total_iter["train"] - self._global_iter_id - 1) * mean_train_time
        eta_sec += len(self.dataloader["val"]) * np.ceil(self._total_iter["train"] / self.val_step) * mean_est_val_time
        eta = decode_eta(eta_sec)

        # print report
        iter_report = self.__iter_report_template.format(
            epoch_id=epoch_id + 1,
            iter_id=self._global_iter_id + 1,
            total_iter=self._total_iter["train"],
            train_loss=round(np.mean([v for v in self.log["train"]["loss"]]), 5),
            train_ref_loss=round(np.mean([v for v in self.log["train"]["ref_loss"]]), 5),
            train_lang_loss=round(np.mean([v for v in self.log["train"]["lang_loss"]]), 5),
            train_seg_loss=round(np.mean([v for v in self.log["train"]["seg_loss"]]), 5),
            train_lang_acc=round(np.mean([v for v in self.log["train"]["lang_acc"]]), 5),
            train_ref_acc=round(np.mean([v for v in self.log["train"]["ref_acc"]]), 5),
            train_seg_acc=round(np.mean([v for v in self.log["train"]["seg_acc"]]), 5),
            train_iou_rate_25=round(self.log['train']['iou_rate_0.25'], 5),
            train_iou_rate_5=round(self.log['train']['iou_rate_0.5'], 5),
            mean_fetch_time=round(np.mean(fetch_time), 5),
            mean_forward_time=round(np.mean(forward_time), 5),
            mean_backward_time=round(np.mean(backward_time), 5),
            mean_eval_time=round(np.mean(eval_time), 5),
            mean_iter_time=round(np.mean(iter_time), 5),
            eta_h=eta["h"],
            eta_m=eta["m"],
            eta_s=eta["s"]
        )
        self._log(iter_report)

    def _epoch_report(self, epoch_id):
        """Format and log the end-of-epoch validation report."""
        self._log("epoch [{}/{}] done...".format(epoch_id + 1, self.epoch))
        epoch_report = self.__epoch_report_template.format(
            val_loss=round(np.mean([v for v in self.log["val"]["loss"]]), 5),
            val_seg_loss=round(np.mean([v for v in self.log["val"]["seg_loss"]]), 5),
            val_ref_loss=round(np.mean([v for v in self.log["val"]["ref_loss"]]), 5),
            val_lang_loss=round(np.mean([v for v in self.log["val"]["lang_loss"]]), 5),
            val_lang_acc=round(np.mean([v for v in self.log["val"]["lang_acc"]]), 5),
            val_seg_acc=round(np.mean([v for v in self.log["val"]["seg_acc"]]), 5),
            val_ref_acc=round(np.mean([v for v in self.log["val"]["ref_acc"]]), 5),
            val_iou_rate_25=round(self.log['val']['iou_rate_0.25'], 5),
            val_iou_rate_5=round(self.log['val']['iou_rate_0.5'], 5),
        )
        self._log(epoch_report)

    def _best_report(self):
        """Log the best validation results and write them to best.txt."""
        self._log("training completed...")
        best_report = self.__best_report_template.format(
            epoch=self.best["epoch"],
            loss=round(self.best["loss"], 5),
            ref_loss=round(self.best["ref_loss"], 5),
            lang_loss=round(self.best["lang_loss"], 5),
            lang_acc=round(self.best["lang_acc"], 5),
            ref_acc=round(self.best["ref_acc"], 5),
            iou_rate_25=round(self.best["iou_rate_0.25"], 5),
            iou_rate_5=round(self.best["iou_rate_0.5"], 5),
        )
        self._log(best_report)
        with open(os.path.join(CONF.PATH.OUTPUT, self.stamp, "best.txt"), "w") as f:
            f.write(best_report)

    def init_log(self):
        """(Re)initialize the per-phase metric/timing accumulators."""
        # contains all necessary info for all phases
        self.log = {
            phase: {
                # info
                "forward": [],
                "backward": [],
                "eval": [],
                "fetch": [],
                "iter_time": [],
                # loss (float, not torch.cuda.FloatTensor)
                "loss": [],
                "ref_loss": [],
                "lang_loss": [],
                "seg_loss": [],
                # scores (float, not torch.cuda.FloatTensor)
                "lang_acc": [],
                "ref_acc": [],
                "seg_acc": [],
                'ref_iou': [],
                "iou_rate_0.25": [],
                "iou_rate_0.5": []
            } for phase in ["train", "val"]
        }
示例#24
0
def main():
    """Entry point: train a CNN (``Net``) on MNIST with Adadelta + StepLR.

    Parses command-line hyperparameters, builds the train/test loaders
    (downloading MNIST to ``../data`` if needed), trains for ``--epochs``
    epochs while logging the scheduler learning rate after each epoch, and
    optionally saves the final weights to ``mnist_cnn.pt``.

    Relies on module-level ``Net``, ``train``, ``test`` and ``metrics``
    defined elsewhere in this file.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=14,
                        metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr',
                        type=float,
                        default=1.0,
                        metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.7,
                        metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    # Reproducibility of weight init and shuffling.
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Pinned memory + workers only help when transferring to the GPU.
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        # per-channel mean/std of the MNIST training set
        transforms.Normalize((0.1307, ), (0.3081, ))
    ])
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=True,
        download=True,
        transform=mnist_transform),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=False,
        transform=mnist_transform),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              **kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # Decay LR by gamma every epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)

        # BUG FIX: get_last_lr() returns a list with one LR per param group;
        # log the scalar for the single group instead of the wrapping list.
        metrics.send_metric("scheduler_lr", scheduler.get_last_lr()[0])

        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
class Runner(object):
    """Few-shot episodic trainer for a matching network on MiniImageNet.

    Builds the episodic training loader, trains the matching network with an
    MSE loss over support/query similarity scores, periodically validates via
    ``TestTool`` and checkpoints the best-accuracy weights to ``Config.mn_dir``.
    Relies on module-level ``Config``, ``RunnerTool``, ``Tools``, ``Normalize``,
    ``MiniImageNetDataset`` and ``TestTool`` defined elsewhere in this file.
    """

    def __init__(self):
        # best validation accuracy seen so far (for checkpoint selection)
        self.best_accuracy = 0.0

        # all data
        self.data_train = MiniImageNetDataset.get_data_all(Config.data_root)
        self.task_train = MiniImageNetDataset(self.data_train, Config.num_way,
                                              Config.num_shot)
        self.task_train_loader = DataLoader(self.task_train,
                                            Config.batch_size,
                                            True,
                                            num_workers=Config.num_workers)

        # model
        self.matching_net = RunnerTool.to_cuda(Config.matching_net)
        RunnerTool.to_cuda(self.matching_net.apply(RunnerTool.weights_init))
        # L2 normalization over the feature dimension
        self.norm = Normalize(2)

        # config switches: optionally normalize queries / softmax similarities
        self.has_norm = Config.has_norm
        self.has_softmax = Config.has_softmax

        # loss
        self.loss = RunnerTool.to_cuda(nn.MSELoss())

        # optim
        self.matching_net_optim = torch.optim.Adam(
            self.matching_net.parameters(), lr=Config.learning_rate)
        # halve the LR every third of the total training epochs
        self.matching_net_scheduler = StepLR(self.matching_net_optim,
                                             Config.train_epoch // 3,
                                             gamma=0.5)

        self.test_tool = TestTool(self.matching_test,
                                  data_root=Config.data_root,
                                  num_way=Config.num_way,
                                  num_shot=Config.num_shot,
                                  episode_size=Config.episode_size,
                                  test_episode=Config.test_episode,
                                  transform=self.task_train.transform_test)
        pass

    def load_model(self):
        """Load saved matching-net weights from Config.mn_dir if present."""
        if os.path.exists(Config.mn_dir):
            self.matching_net.load_state_dict(torch.load(Config.mn_dir))
            Tools.print("load proto net success from {}".format(Config.mn_dir))
        pass

    def matching(self, task_data):
        """Score query images against support images for a batch of episodes.

        Returns per-way predictions of shape (batch, num_way), obtained by
        averaging support/query similarities over the shots of each way.
        """
        # NOTE(review): `data_weight` presumably means image height — verify.
        data_batch_size, data_image_num, data_num_channel, data_width, data_weight = task_data.shape
        # flatten episodes so all images go through the backbone in one pass
        data_x = task_data.view(-1, data_num_channel, data_width, data_weight)
        net_out = self.matching_net(data_x)
        z = net_out.view(data_batch_size, data_image_num, -1)

        # features
        # split into support (num_shot * num_way images) and query features
        z_support, z_query = z.split(Config.num_shot * Config.num_way, dim=1)
        z_batch_size, z_num, z_dim = z_support.shape
        z_support = z_support.view(z_batch_size,
                                   Config.num_way * Config.num_shot, z_dim)
        # NOTE(review): expand() assumes the query dim is broadcastable to
        # num_way * num_shot (i.e. a single query per episode) — confirm.
        z_query_expand = z_query.expand(z_batch_size,
                                        Config.num_way * Config.num_shot,
                                        z_dim)

        # similarity
        z_support = self.norm(z_support)
        z_query_expand = self.norm(
            z_query_expand) if self.has_norm else z_query_expand
        # dot product per support/query pair
        similarities = torch.sum(z_support * z_query_expand, -1)
        similarities = torch.softmax(
            similarities, dim=1) if self.has_softmax else similarities
        similarities = similarities.view(z_batch_size, Config.num_way,
                                         Config.num_shot)
        # average over shots -> one score per way
        predicts = torch.mean(similarities, dim=-1)
        return predicts

    def matching_test(self, samples, batches):
        """Test-time scoring: compare each batch image to all support samples.

        Returns predictions of shape (batch_num, num_way).
        """
        batch_num, _, _, _ = batches.shape

        sample_z = self.matching_net(samples)  # 5x64*5*5
        batch_z = self.matching_net(batches)  # 75x64*5*5
        z_support = sample_z.view(Config.num_way * Config.num_shot, -1)
        z_query = batch_z.view(batch_num, -1)
        _, z_dim = z_query.shape

        # broadcast support against every query and vice versa
        z_support_expand = z_support.unsqueeze(0).expand(
            batch_num, Config.num_way * Config.num_shot, z_dim)
        z_query_expand = z_query.unsqueeze(1).expand(
            batch_num, Config.num_way * Config.num_shot, z_dim)

        # similarity
        z_support_expand = self.norm(z_support_expand)
        z_query_expand = self.norm(
            z_query_expand) if self.has_norm else z_query_expand
        similarities = torch.sum(z_support_expand * z_query_expand, -1)
        similarities = torch.softmax(
            similarities, dim=1) if self.has_softmax else similarities
        similarities = similarities.view(batch_num, Config.num_way,
                                         Config.num_shot)
        # average over shots -> one score per way
        predicts = torch.mean(similarities, dim=-1)
        return predicts

    def train(self):
        """Run the full training loop with periodic validation.

        Steps the LR scheduler once per epoch and saves weights whenever
        validation accuracy improves.
        """
        Tools.print()
        Tools.print("Training...")

        for epoch in range(Config.train_epoch):
            self.matching_net.train()

            Tools.print()
            all_loss = 0.0
            for task_data, task_labels, task_index in tqdm(
                    self.task_train_loader):
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)

                # 1 calculate features
                predicts = self.matching(task_data)

                # 2 loss
                loss = self.loss(predicts, task_labels)
                all_loss += loss.item()

                # 3 backward
                self.matching_net.zero_grad()
                loss.backward()
                self.matching_net_optim.step()
                ###########################################################################
                pass

            ###########################################################################
            # print
            Tools.print("{:6} loss:{:.3f} lr:{}".format(
                epoch + 1, all_loss / len(self.task_train_loader),
                self.matching_net_scheduler.get_last_lr()))

            self.matching_net_scheduler.step()
            ###########################################################################

            ###########################################################################
            # Val
            if epoch % Config.val_freq == 0:
                Tools.print()
                Tools.print("Test {} {} .......".format(
                    epoch, Config.model_name))
                self.matching_net.eval()

                val_accuracy = self.test_tool.val(episode=epoch, is_print=True)
                if val_accuracy > self.best_accuracy:
                    self.best_accuracy = val_accuracy
                    torch.save(self.matching_net.state_dict(), Config.mn_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            ###########################################################################
            pass

        pass

    pass
示例#26
0
def train(params):
    """Run stratified 5-fold training with early stopping and return
    the fold-averaged test-set predictions.

    For every fold the network is trained with a StepLR schedule and
    early stopping on the validation loss; the best checkpoint is then
    restored and the test ("sub") set is predicted 10 times, the passes
    averaged per fold, and the fold averages averaged into the result.

    Args:
        params: dict of hyper-parameters and paths. Keys read here:
            'task', 'seed', 'cols', 'embed_dir', 'seqs_file',
            'feas_file', 'batch_size', 'max_len', 'hidden_size',
            'num_layers', 'drop_out', 'lr', 'gamma', 'label_smooth',
            'early_stop_round', 'num_epochs'.

    Returns:
        pandas.DataFrame of shape (num_sub_users, 20) holding the
        averaged class probabilities, indexed by user id.
    """
    logger = get_logger('{}.log'.format(params['task']),
                        '{}_logger'.format(params['task']))
    logger.info('start {}'.format(params['task']))

    set_all_seed(params['seed'])

    for key, value in params.items():
        logger.info('{} : {}'.format(key, value))

    logger.info('loading seqs, feas and w2v embeddings ...')
    train_val_data, sub_data, embeddings, embed_size, fea_size = load_data(
        params['cols'], params['embed_dir'], params['seqs_file'],
        params['feas_file'])

    logger.info('embed_size : {} | fea_size : {}'.format(embed_size, fea_size))
    batch_size = params['batch_size']
    sub_dataset = SeqDataSet(sub_data['seqs'],
                             sub_data['feas'], sub_data['users'],
                             len(params['cols']), params['max_len'], 'sub')
    # Inference can afford much larger batches (no gradients are kept).
    sub_loader = data.DataLoader(sub_dataset,
                                 batch_size * 10,
                                 shuffle=False,
                                 collate_fn=sub_dataset.collate_fn,
                                 pin_memory=True)

    # Ensemble accumulator: one row per sub user, 20 class columns.
    sub = np.zeros(shape=(sub_data['num'], 20))
    sub = pd.DataFrame(sub, index=sub_data['users'])

    skf = StratifiedKFold(n_splits=5,
                          shuffle=True,
                          random_state=params['seed'])
    num_folds = skf.get_n_splits()

    for i, (train_idx, val_idx) in enumerate(
            skf.split(train_val_data['feas'], train_val_data['labels'])):
        logger.info(
            '------------------------------------------{} fold------------------------------------------'
            .format(i))
        train_dataset = SeqDataSet(train_val_data['seqs'][train_idx],
                                   train_val_data['feas'][train_idx],
                                   train_val_data['labels'][train_idx],
                                   len(params['cols']), params['max_len'],
                                   'train')
        train_loader = data.DataLoader(train_dataset,
                                       batch_size,
                                       shuffle=True,
                                       collate_fn=train_dataset.collate_fn,
                                       pin_memory=True)

        val_dataset = SeqDataSet(train_val_data['seqs'][val_idx],
                                 train_val_data['feas'][val_idx],
                                 train_val_data['labels'][val_idx],
                                 len(params['cols']), params['max_len'], 'val')
        val_loader = data.DataLoader(val_dataset,
                                     batch_size * 10,
                                     shuffle=False,
                                     collate_fn=val_dataset.collate_fn,
                                     pin_memory=True)

        logger.info(
            'train samples : {} | val samples : {} | sub samples : {}'.format(
                len(train_idx), len(val_idx), sub_data['num']))
        logger.info('loading net ...')

        # embed_net holds the (fixed) w2v embeddings; only net's
        # parameters are optimized below.
        embed_net = embedNet(embeddings).cuda()
        net = Net(embed_size, fea_size, params['hidden_size'],
                  params['num_layers'], params['drop_out']).cuda()

        optimizer = optim.AdamW(params=net.parameters(), lr=params['lr'])
        scheduler = StepLR(optimizer, step_size=2, gamma=params['gamma'])
        loss_func = CrossEntropyLabelSmooth(20, params['label_smooth'])

        earlystop = EarlyStopping(params['early_stop_round'], logger,
                                  params['task'] + str(i))

        for epoch in range(params['num_epochs']):
            train_loss, val_loss = 0.0, 0.0
            train_age_acc, val_age_acc = 0.0, 0.0
            train_gender_acc, val_gender_acc = 0.0, 0.0
            train_acc, val_acc = 0.0, 0.0

            # n / m count the train / val samples actually seen.
            n, m = 0, 0
            lr_now = scheduler.get_last_lr()[0]
            logger.info('--> [Epoch {:02d}/{:02d}] lr = {:.7f}'.format(
                epoch, params['num_epochs'], lr_now))

            # Train one epoch.
            net.train()
            for seqs, feas, lens, labels in tqdm(
                    train_loader,
                    desc='[Epoch {:02d}/{:02d}] Train'.format(
                        epoch, params['num_epochs'])):
                seqs = seqs.cuda()
                feas = feas.cuda()
                lens = lens.cuda()
                labels = labels.cuda()

                logits = net(embed_net(seqs), feas, lens)
                loss = loss_func(logits, labels)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                # Labels encode gender*10 + age: % 10 recovers age,
                # // 10 recovers gender.
                train_loss += loss.detach() * labels.shape[0]
                train_age_acc += (logits.argmax(dim=1).detach() %
                                  10 == labels % 10).sum()
                train_gender_acc += (logits.argmax(dim=1).detach() //
                                     10 == labels // 10).sum()

                n += lens.shape[0]

            scheduler.step()

            train_loss = (train_loss / n).item()
            train_age_acc = (train_age_acc / n).item()
            train_gender_acc = (train_gender_acc / n).item()
            train_acc = train_age_acc + train_gender_acc

            # Evaluate on the validation split.
            net.eval()
            with torch.no_grad():
                for seqs, feas, lens, labels in tqdm(
                        val_loader,
                        desc='[Epoch {:02d}/{:02d}]  Val '.format(
                            epoch, params['num_epochs'])):
                    seqs = seqs.cuda()
                    feas = feas.cuda()
                    lens = lens.cuda()
                    labels = labels.cuda()

                    logits = net(embed_net(seqs), feas, lens)
                    loss = loss_func(logits, labels)

                    val_loss += loss.detach() * labels.shape[0]
                    val_age_acc += (logits.argmax(dim=1) % 10 == labels %
                                    10).sum()
                    val_gender_acc += (logits.argmax(dim=1).detach() //
                                       10 == labels // 10).sum()

                    m += lens.shape[0]

                val_loss = (val_loss / m).item()
                val_age_acc = (val_age_acc / m).item()
                val_gender_acc = (val_gender_acc / m).item()
                val_acc = val_age_acc + val_gender_acc

            logger.info(
                'train_loss {:.5f} | train_gender_acc {:.5f} | train_age_acc {:.5f} | train_acc {:.5f} | val_loss {:.5f} | val_gender_acc {:.5f} | val_age_acc {:.5f} | val_acc {:.5f}'
                .format(train_loss, train_gender_acc, train_age_acc, train_acc,
                        val_loss, val_gender_acc, val_age_acc, val_acc))

            # Early stopping (checkpoints the best model internally).
            earlystop(val_loss, val_acc, net)
            if earlystop.early_stop:
                break

        # BUG FIX: the original code had an unconditional `break` here,
        # which made everything below unreachable — the function always
        # returned the zero-filled `sub`.  It has been removed so every
        # fold restores its best checkpoint and contributes predictions.
        net.load_state_dict(
            torch.load('{}_checkpoint.pt'.format(params['task'] + str(i))))
        logger.info('predicting sub ...')
        net.eval()
        fold_probs = pd.DataFrame(np.zeros_like(sub.values), index=sub.index)
        with torch.no_grad():
            # NOTE(review): with net.eval() the 10 passes are identical
            # unless the 'sub' dataset / collate_fn applies random
            # augmentation — confirm before relying on the averaging.
            for it in range(10):
                probs = []
                users = []
                for seqs, feas, lens, ids in tqdm(
                        sub_loader, desc='predict_{}'.format(it)):
                    seqs = seqs.cuda()
                    feas = feas.cuda()
                    lens = lens.cuda()

                    logits = net(embed_net(seqs), feas, lens)
                    logits = F.softmax(logits, dim=1)

                    probs.append(logits)
                    users.append(ids)

                probs = torch.cat(probs).cpu().numpy()
                users = torch.cat(users).numpy()
                fold_probs += pd.DataFrame(probs, users)
        # BUG FIX: the original divided the shared `sub` accumulator by
        # 10 inside the fold loop, compounding the division on earlier
        # folds' contributions.  Average the 10 passes per fold instead.
        sub += fold_probs / 10

    # Average the per-fold contributions into the final submission.
    return sub / num_folds