Example no. 1
    def __init__(self, option, model, train_loader, val_loader, test_loader,
                 optimizer, criterion):
        self.option = option
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.optimizer = optimizer
        self.criterion = criterion

        self.epoch_loss_plotter = tnt.logger.VisdomPlotLogger(
            'line',
            opts={
                'title': 'Epoch Loss',
                'xlabel': "Epochs",
                'ylabel': "Loss"
            })
        self.batch_loss_plotter = IncrementVisdomLineLogger(opts={
            'title': 'Batch Loss',
            'xlabel': "Batch",
            'ylabel': "Loss"
        })

        self.checkpoint = Checkpoint(option)
        self.best_top1 = 0
        self.start_epoch = 0
        self._load_checkpoint()
Example no. 2
    def train(self,
              model,
              data,
              num_epochs=5,
              resume=False,
              dev_data=None,
              optimizer=None,
              teacher_forcing_ratio=0):
        """ Run training for a given model.

        Args:
            model (seq2seq.models): model to run training on; if `resume=True`, it will be
               overwritten by the model loaded from the latest checkpoint.
            data (seq2seq.dataset.dataset.Dataset): dataset object to train on
            num_epochs (int, optional): number of epochs to run (default 5)
            resume (bool, optional): resume training from the latest checkpoint (default False)
            dev_data (seq2seq.dataset.dataset.Dataset, optional): dev Dataset (default None)
            optimizer (seq2seq.optim.Optimizer, optional): optimizer for training
               (default: Optimizer(pytorch.optim.Adam, max_grad_norm=5))
            teacher_forcing_ratio (float, optional): teacher forcing ratio (default 0)
        Returns:
            model (seq2seq.models): trained model.
        """
        # If training is set to resume
        if resume:
            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                self.expt_dir)
            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
            model = resume_checkpoint.model
            self.optimizer = resume_checkpoint.optimizer

            # A workaround to set the optimizer's parameters properly
            resume_optim = self.optimizer.optimizer
            defaults = resume_optim.param_groups[0]
            defaults.pop('params', None)
            defaults.pop('initial_lr', None)
            self.optimizer.optimizer = resume_optim.__class__(
                model.parameters(), **defaults)

            start_epoch = resume_checkpoint.epoch
            step = resume_checkpoint.step
        else:
            start_epoch = 1
            step = 0
            if optimizer is None:
                optimizer = Optimizer(optim.Adam(model.parameters()),
                                      max_grad_norm=5)
            self.optimizer = optimizer

        self.logger.info("Optimizer: %s, Scheduler: %s" %
                         (self.optimizer.optimizer, self.optimizer.scheduler))

        self._train_epochs(data,
                           model,
                           num_epochs,
                           start_epoch,
                           step,
                           dev_data=dev_data,
                           teacher_forcing_ratio=teacher_forcing_ratio)
        return model
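A hedged, self-contained sketch of the optimizer-rebuild workaround used above (popping 'params' and 'initial_lr' from a restored optimizer's param_groups and re-creating it against the new model's parameters). The model and optimizer below are placeholders for illustration, not objects from the snippet:

import torch.nn as nn
from torch import optim

old_model = nn.Linear(4, 2)
loaded_optim = optim.Adam(old_model.parameters(), lr=3e-4)  # stands in for the optimizer restored from a checkpoint

new_model = nn.Linear(4, 2)
defaults = dict(loaded_optim.param_groups[0])
defaults.pop('params', None)      # the parameters belong to the old model
defaults.pop('initial_lr', None)  # only present when a scheduler was attached
rebuilt_optim = loaded_optim.__class__(new_model.parameters(), **defaults)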
Example no. 3
def main():
    train_loader, test_loader = get_mnist_data('../%s' % opt.dataset,
                                               opt.batch_size)
    model = CapsuleNetwork(opt)
    if opt.cuda:
        model = model.cuda()

    if opt.is_train:
        if opt.resume:
            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                opt.save_folder)
            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
            model = resume_checkpoint.model
            optimizer = resume_checkpoint.optimizer
            start_epoch = resume_checkpoint.epoch + 1
        else:
            start_epoch = 0
            optimizer = Adam(model.parameters())

        for epoch in range(start_epoch, opt.n_epochs):
            train(epoch, model, train_loader, test_loader, optimizer)
            Checkpoint(model=model, optimizer=optimizer,
                       epoch=epoch).save(opt.save_folder)
    else:
        run_test(model, test_loader)
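The resume-or-start-fresh control flow above relies on the project's Checkpoint helper; as a hedged, framework-only sketch of the same idea, one can persist the state dicts with plain torch.save/torch.load (an assumption for illustration, not the code's actual mechanism):

import os
import torch
import torch.nn as nn
from torch.optim import Adam

def resume_or_init(model, save_path):
    """Return (model, optimizer, start_epoch); load state if a checkpoint file exists."""
    optimizer = Adam(model.parameters())
    start_epoch = 0
    if os.path.exists(save_path):
        state = torch.load(save_path, map_location="cpu")
        model.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
        start_epoch = state["epoch"] + 1
    return model, optimizer, start_epoch

def save_state(model, optimizer, epoch, save_path):
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch}, save_path)

model, optimizer, start_epoch = resume_or_init(nn.Linear(10, 2), "last.pt")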
Example no. 4
    def __init__(self,
                 option,
                 model,
                 train_loader,
                 val_loader,
                 test_loader,
                 optimizer,
                 criterion,
                 client_loaders,
                 sybil_loaders,
                 iidness=[.0, .0]):
        self.option = option
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.iidness = iidness

        self.epoch_loss_plotter = tnt.logger.VisdomPlotLogger(
            'line',
            opts={
                'title': 'Epoch Loss',
                'xlabel': "Epochs",
                'ylabel': "Loss"
            })
        self.batch_loss_plotter = IncrementVisdomLineLogger(opts={
            'title': 'Batch Loss',
            'xlabel': "Batch",
            'ylabel': "Loss"
        })
        self.train_confusion_plotter = tnt.logger.VisdomLogger(
            'heatmap',
            opts={
                'title': 'Train Confusion matrix',
                'columnnames': list(range(option.n_classes)),
                'rownames': list(range(option.n_classes))
            })
        self.val_confusion_plotter = tnt.logger.VisdomLogger(
            'heatmap',
            opts={
                'title': 'Val Confusion matrix',
                'columnnames': list(range(option.n_classes)),
                'rownames': list(range(option.n_classes))
            })

        self.memory = None
        self.wv_history = []
        self.client_loaders = client_loaders
        self.sybil_loaders = sybil_loaders

        self.checkpoint = Checkpoint(option)
        self.best_top1 = 0
        self.start_epoch = 0
        self._load_checkpoint()
Example no. 5
 def checkpoint(self, cp):
     "Create a checkpoint from arbitrary object 'cp'"
     checkpoint = Checkpoint(self, cp, author="anonymous")
     valid = checkpoint.test()
     print("Tested checkpoint %r and got result %r" % (cp, valid))
     if valid:
         if self.owner:
             checkpoint.quorum.sign(self.identity)
             self.owner.attempt_checkpoint(self, checkpoint)
         else:
             checkpoint.enact()
Example no. 6
class Game:
    def __init__(self, width, height, bg_color):
        pygame.init()
        self.displaySurface = pygame.display.set_mode((width, height))
        self.bgColor = bg_color
        pygame.display.set_caption("GOKU")

        self.tilemap = tmx.load("test.tmx", self.displaySurface.get_size())
        self.players = tmx.SpriteLayer()
        self.enemies = tmx.SpriteLayer()
        self.blasts = tmx.SpriteLayer()
        self.hud = tmx.SpriteLayer()
        self.dragonball = tmx.SpriteLayer()
        self.checkpoints = Checkpoint()
        #self.bg = pygame.image.load('res/Map/dbz_background.jpg')
        player_cell = self.tilemap.layers["triggers"].find("player")[0]
        enemy_cells = self.tilemap.layers["triggers"].find("enemy")
        checkpoint_cells = self.tilemap.layers["triggers"].find("checkpoint")
        finish = self.tilemap.layers["triggers"].find("finish")[0]

        for cell in enemy_cells:
            self.enemies.add(Henchmen2((cell.left, cell.bottom)))

        for checkpoint in checkpoint_cells:
            self.checkpoints.add_checkpoint((checkpoint.px, checkpoint.py))

        self.goku = Goku((player_cell.px, player_cell.py))
        #  self.vegeta = Vegeta((400, 200))
        self.players.add(self.goku)
        self.hud.add(Hud(self.goku))
        self.dragonball.add(Dragonball((finish.px, finish.py)))
        # self.players.add(self.vegeta)
        self.fpsClock = pygame.time.Clock()
        self.tilemap.layers.append(self.players)
        self.tilemap.layers.append(self.enemies)
        self.tilemap.layers.append(self.blasts)
        self.tilemap.layers.append(self.hud)
        self.tilemap.layers.append(self.dragonball)

    def main(self):
        while True:  # Main Game Loop
            dt = self.fpsClock.tick(30)
            for event in pygame.event.get():
                if event.type == QUIT:
                    pygame.quit()
                    sys.exit()

            self.displaySurface.fill(self.bgColor)
            #self.displaySurface.blit(self.bg, (0,0))
            # self.tilemap.set_focus(self.goku.rect.x, self.goku.rect.y)
            self.tilemap.update(dt / 1000, self)
            self.tilemap.draw(self.displaySurface)
            pygame.display.update()
Example no. 7
def train(args):

    configure(args['log_dir'])

    dial_data = get_dataloader(
        os.path.join(args['data_dir'], 'encoded_train_dialogue_pair.json'),
        os.path.join(args['data_dir'], 'vocabulary.json'), args['batch_size'])
    vocab = Vocabulary()
    vocab.load_vocab(os.path.join(args['data_dir'], 'vocabulary.json'))
    args['voca_size'] = len(vocab.word2idx)

    model = Seq2Seq(args).cuda() if torch.cuda.is_available() else Seq2Seq(
        args)

    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])

    criterion = nn.NLLLoss(ignore_index=vocab.get_idx('PADED'))

    min_valid_loss = float('inf')

    for epoch in range(args['epoches']):
        for batch_idx, (sour, sour_len, targ,
                        targ_len) in enumerate(dial_data):
            if torch.cuda.is_available():
                sour = sour.cuda()
                targ = targ.cuda()
            loss = train_batch(model, optimizer, criterion,
                               (sour, sour_len, targ, targ_len))

            logger.info('training loss:{}'.format(loss))
            log_value('CrossEntropy loss', loss,
                      epoch * len(dial_data) + batch_idx)

            if (batch_idx + epoch * len(dial_data)) % args['valid_step'] == 0:
                valid_loader = get_dataloader(
                    os.path.join(args['data_dir'],
                                 'encoded_valid_dialogue_pair.json'),
                    os.path.join(args['data_dir'], 'vocabulary.json'),
                    args['batch_size'])
                valid_loss = validate(model, valid_loader, criterion)

                log_value(
                    'valid loss', valid_loss,
                    int((batch_idx + epoch * len(dial_data)) /
                        args['valid_step']))
                logger.info('valid_step:{} valid_loss:{}'.format(
                    int((batch_idx + epoch * len(dial_data)) /
                        args['valid_step']), valid_loss))

                checkpoint = Checkpoint(model, optimizer, epoch, batch_idx)
                checkpoint.save(args['exp_dir'])
Example no. 8
    def testCheckLogCreation(self):
        x = Checkpoint(TestCheckpoint.TEST_DIR)

        x.createCheckpointLog(TestCheckpoint.TEST_KEY)
        self.assertTrue(TestCheckpoint.TEST_KEY in x.getCheckpointLogKeys())

        x.releaseCheckpointLog(TestCheckpoint.TEST_KEY)
        self.assertFalse(TestCheckpoint.TEST_KEY in x.getCheckpointLogKeys())
Example no. 9
 def row_to_object(row):
     checkpoint_pass_object = CheckpointPass()
     checkpoint_pass_object.id = row.id
     checkpoint_pass_object.user_id = row.user_id
     checkpoint_pass_object.time = row.time
     checkpoint_pass_object.checkpoint = Checkpoint.row_to_object(row.checkpoint)
     return checkpoint_pass_object
Example no. 10
def test(args):

    vocab = Vocabulary()
    vocab.load_vocab(os.path.join(args['data_dir'], 'vocabulary.json'))
    args['voca_size'] = vocab.get_vocab_size()
    test_data = get_dataloader(
        os.path.join(args['data_dir'], 'encoded_test_dialogue_pair.json'),
        os.path.join(args['data_dir'], 'vocabulary.json'), 1)
    test_sent_pair_list = []

    model = Seq2Seq(args).eval()
    if torch.cuda.is_available():
        model = model.cuda()

    path = Checkpoint.get_latest_checkpoint(args['exp_dir'])
    model.load_state_dict(torch.load(os.path.join(path, 'model.pt')))

    for batch_idx, (sour, sour_len, targ, targ_len) in enumerate(test_data):
        if torch.cuda.is_available():
            sour = sour.cuda()
            targ = targ.cuda()
        enco_hidd_state = model.encoder.encoder_forward(sour, sour_len)
        out_prob = model.decoder.decoder_forward(targ, targ_len,
                                                 enco_hidd_state, 0)
        sent_list = [(out_prob.topk(1)[1].view(-1).tolist(), 0)]
        test_sent_pair_list += process_sent_list(vocab, sour, targ, sent_list)
#   logger.info('batch_idx:{} \nsent:{}'.format(batch_idx,test_sent_pair_list))

    save_test_sent(args['exp_data'], 'generated_test_sent.txt',
                   test_sent_pair_list)
Example no. 11
    def do_checkpoint(self):
        # When making a checkpoint, first write workq and workq_buf into the checkpoint file, then make a copy of workq_db if it exists
        for k in self.wfd_cache.keys():
            os.close(self.wfd_cache[k])

        # clear the cache
        self.wfd_cache.clear()

        tmp_file = self.checkpoint_file + ".part"
        with open(tmp_file, "wb") as f:
            self.circle.workq.extend(self.circle.workq_buf)
            self.circle.workq_buf.clear()
            cobj = Checkpoint(self.src, self.dest, self.get_workq(), self.totalsize)
            pickle.dump(cobj, f, pickle.HIGHEST_PROTOCOL)
        # POSIX requires rename to be atomic
        os.rename(tmp_file, self.checkpoint_file)

        # copy workq_db database file
        if hasattr(self.circle, "workq_db") and len(self.circle.workq_db) > 0:
            self.checkpoint_db = self.checkpoint_file + ".db"
            if not G.resume:
                shutil.copy2(self.circle.dbname, self.checkpoint_db)
            else:
                # in resume mode, make a copy of the current workq db file, which is the provided checkpoint db file
                self.workdir = os.getcwd()
                existingCheckpoint = os.path.join(self.workdir, ".pcp_workq.%s.%s.db" % (G.rid, self.circle.rank))
                shutil.copy2(existingCheckpoint, self.checkpoint_db)
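The write-to-temp-file-then-rename step above is what makes the checkpoint crash-safe; a minimal, self-contained sketch of that pattern in isolation:

import os
import pickle

def save_checkpoint_atomically(state, checkpoint_file):
    """Write state to a temp file first, then rename over the target, so a
    crash mid-write never leaves a truncated checkpoint behind (rename is
    atomic on POSIX). state is any picklable object."""
    tmp_file = checkpoint_file + ".part"
    with open(tmp_file, "wb") as f:
        pickle.dump(state, f, pickle.HIGHEST_PROTOCOL)
    os.rename(tmp_file, checkpoint_file)

save_checkpoint_atomically({"step": 42, "workq": ["item-a", "item-b"]}, "workq.ckpt")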
Example no. 12
def load_model():
    checkpoint_path = ""
    if FLAGS.load_checkpoint is not None:
        checkpoint_path = os.path.join(FLAGS.expt_dir,
                                       Checkpoint.CHECKPOINT_DIR_NAME,
                                       FLAGS.load_checkpoint)
    else:
        checkpoint_path = Checkpoint.get_latest_checkpoint(FLAGS.expt_dir)
    logging.info("loading checkpoint from {}".format(checkpoint_path))
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    # these are vocab classes with members stoi and itos
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab
    classifier = (seq2seq, input_vocab, output_vocab)

    return classifier
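Checkpoint.get_latest_checkpoint is called above but not shown; a hedged sketch of what such a lookup might do, assuming checkpoints are stored as subdirectories of the experiment directory (not the library's actual implementation):

import os

def latest_checkpoint(expt_dir, subdir="checkpoints"):
    """Return the newest checkpoint directory under expt_dir/subdir, or None."""
    root = os.path.join(expt_dir, subdir)
    if not os.path.isdir(root):
        return None
    candidates = [os.path.join(root, d) for d in os.listdir(root)]
    candidates = [d for d in candidates if os.path.isdir(d)]
    return max(candidates, key=os.path.getmtime) if candidates else None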
Example no. 13
    def init_checkpoints(self):
        self.checkpoint_group.add(Checkpoint("goomba_set_1", 1200))
        self.checkpoint_group.add(Checkpoint("goomba_set_2", 2600))
        self.checkpoint_group.add(Checkpoint("goomba_set_3", 3400))
        self.checkpoint_group.add(Checkpoint("goomba_set_4", 4100))
        self.checkpoint_group.add(Checkpoint("goomba_set_5", 4610))
        self.checkpoint_group.add(Checkpoint("goomba_set_6", 4720))
        self.checkpoint_group.add(Checkpoint("goomba_set_7", 6210))

        self.checkpoint_group.add(Checkpoint("koopa_set_1", 3830))
Example no. 14
	def testCheckLogCreation(self):
		x = Checkpoint(TestCheckpoint.TEST_DIR)
		
		x.createCheckpointLog(TestCheckpoint.TEST_KEY)
		self.assertTrue(TestCheckpoint.TEST_KEY in x.getCheckpointLogKeys())

		x.releaseCheckpointLog(TestCheckpoint.TEST_KEY)
		self.assertFalse(TestCheckpoint.TEST_KEY in x.getCheckpointLogKeys())
Example no. 15
def main():
    if len(sys.argv) != 5:
        print("Usage: " + sys.argv[0] +
              " <directory> <album name> <app key> <app secret>")
        sys.exit(1)

    dir = sys.argv[1]
    checkpoint = Checkpoint(dir)

    photos = get_photos(dir)

    upload_photos(sys.argv[2], photos, sys.argv[3], sys.argv[4], checkpoint)
Example no. 16
    def train(self, encoder, decoder, n_epochs, train_data, dev_data,
              resume, optimizer, log_file):
        """
        ------------------------------------------------------------------------
        Args:
            encoder:                  Self explanatory.
            decoder:                  Self explanatory.
            n_epochs (int):           Number of epochs to train the model.
            train_data (Composition): Self explanatory.
            dev_data (Composition):   Self explanatory.
            resume (bool):            If true, load last checkpoint.
        ------------------------------------------------------------------------
        """
        if resume:
            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(self.exp_dir)
            resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
            encoder        = resume_checkpoint.encoder
            decoder        = resume_checkpoint.decoder
            start_epoch    = resume_checkpoint.epoch
            step           = resume_checkpoint.step
            self.scheduler = resume_checkpoint.scheduler
            self.optimizer = resume_checkpoint.optimizer
            self.samp_rate = resume_checkpoint.samp_rate
            self.KL_rate   = resume_checkpoint.KL_rate
            self.free_bits = resume_checkpoint.free_bits
            self.vocab_size = decoder.vocab_size
        else:
            self.optimizer = optimizer
            if optimizer is None:
                params = list(encoder.parameters()) + list(decoder.parameters())
                self.optimizer = Adam(params, lr=1e-3)
            self.scheduler = LambdaLR(self.optimizer,decay)
            self.vocab_size = decoder.vocab_size

            start_epoch = 1
            step = 0

        self.train_epochs(encoder, decoder, start_epoch, step, train_data, dev_data,
                        start_epoch + n_epochs, log_file)
        return encoder,decoder
Example no. 17
def main():
    # create target output dir if it doesn't exist yet
    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)

    # enable mixed-precision computation if desired
    if args.amp:
        mixed_precision.enable_mixed_precision()

    # set the RNG seeds (probably more hidden elsewhere...)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # get the dataset
    dataset = get_dataset(args.dataset)
    enc_size = get_encoder_size(dataset)

    # get a helper object for tensorboard logging
    log_dir = os.path.join(args.output_dir, args.run_name)
    stat_tracker = StatTracker(log_dir=log_dir)

    # get dataloaders for training and testing
    train_loader, test_loader, num_classes = \
        build_dataset(dataset=dataset,
                      batch_size=args.batch_size,
                      input_dir=args.input_dir,
                      labeled_only=args.classifiers)

    torch_device = torch.device('cuda')
    # create new model with random parameters
    model = Model(ndf=args.ndf,
                  n_classes=num_classes,
                  n_rkhs=args.n_rkhs,
                  tclip=args.tclip,
                  n_depth=args.n_depth,
                  enc_size=enc_size,
                  use_bn=(args.use_bn == 1))
    model.init_weights(init_scale=1.0)
    # restore model parameters from a checkpoint if requested
    checkpoint = Checkpoint(model, args.cpt_load_path, args.output_dir,
                            args.cpt_name)
    model = model.to(torch_device)

    # select which type of training to do
    task = train_classifiers if args.classifiers else train_self_supervised

    # do the real stuff...
    task(model, args.learning_rate, dataset, train_loader, test_loader,
         stat_tracker, checkpoint, args.output_dir, torch_device)
Example no. 18
def run_test(model, test_loader):
    latest_checkpoint_path = Checkpoint.get_latest_checkpoint(opt.save_folder)
    resume_checkpoint = Checkpoint.load(latest_checkpoint_path)
    model = resume_checkpoint.model
    optimizer = resume_checkpoint.optimizer

    model.eval()
    test_loss = 0
    num_error = 0
    num_data = 0
    for batch_id, (data, target) in enumerate(test_loader):
        data = Variable(data)
        if opt.cuda:
            data = data.cuda()

        output, mask, recon = model(data)
        out_mag = torch.sqrt((output**2).sum(2))
        out_mag = F.softmax(out_mag, dim=1)
        max_val, max_idx = out_mag.max(dim=1)

        for idx in range(data.size(0)):
            print("(batch_index, sample_index, estimated, target) : ", batch_id,
                  idx, max_idx[idx].data.cpu().numpy(), target[idx])
            if max_idx[idx].data.cpu().numpy() != target[idx]:
                num_error = num_error + 1
            num_data = num_data + 1
        if opt.vis:
            idx = random.randint(0, data.size(0) - 1)
            show_recon = recon[idx].data.cpu().numpy().reshape(28, 28)
            show_data = data[idx].data.cpu().numpy().reshape(28, 28)

            cv2.namedWindow("recon", cv2.WINDOW_NORMAL)
            cv2.imshow("recon", np.concatenate((show_data, show_recon),
                                               axis=1))
            cv2.waitKey(1)
    print('test error : ', float(num_error) / float(num_data))
Example no. 19
    def set_checkpoints(self):
        pressed_keys = pygame.key.get_pressed()
        checkpoint_start = None
        checkpoint_id = 0

        while (not pressed_keys[K_RETURN]):
            pressed_keys = pygame.key.get_pressed()
            self.on_render()
            for checkpoint in self._checkpoints:
                checkpoint.draw(self._display_surface)
                pygame.display.update()

            if checkpoint_start is not None:
                pygame.draw.line(self._display_surface, (255, 255, 0),
                                 checkpoint_start,
                                 pygame.mouse.get_pos(),
                                 width=3)
                pygame.display.update()

            for event in pygame.event.get():
                left, _, _ = pygame.mouse.get_pressed()
                if left:
                    if checkpoint_start is None:
                        checkpoint_start = pygame.mouse.get_pos()
                    else:
                        checkpoint = Checkpoint(str(checkpoint_id),
                                                checkpoint_start,
                                                pygame.mouse.get_pos())
                        self._checkpoints.append(checkpoint)
                        checkpoint_start = None
                        checkpoint_id += 1

                # Handle if the D key is pressed
                elif event.type == pygame.KEYDOWN and event.key == K_d:
                    # If the D key was pressed and we were drawing a new checkpoint, cancel
                    if checkpoint_start is not None:
                        checkpoint_start = None
                        continue
                    # Otherwise - remove the last drawn checkpoint
                    elif len(self._checkpoints) > 0:
                        self._checkpoints = self._checkpoints[:-1]
Example no. 20
    def __init__(self, width, height, bg_color):
        pygame.init()
        self.displaySurface = pygame.display.set_mode((width, height))
        self.bgColor = bg_color
        pygame.display.set_caption("GOKU")

        self.tilemap = tmx.load("test.tmx", self.displaySurface.get_size())
        self.players = tmx.SpriteLayer()
        self.enemies = tmx.SpriteLayer()
        self.blasts = tmx.SpriteLayer()
        self.hud = tmx.SpriteLayer()
        self.dragonball = tmx.SpriteLayer()
        self.checkpoints = Checkpoint()
        #self.bg = pygame.image.load('res/Map/dbz_background.jpg')
        player_cell = self.tilemap.layers["triggers"].find("player")[0]
        enemy_cells = self.tilemap.layers["triggers"].find("enemy")
        checkpoint_cells = self.tilemap.layers["triggers"].find("checkpoint")
        finish = self.tilemap.layers["triggers"].find("finish")[0]

        for cell in enemy_cells:
            self.enemies.add(Henchmen2((cell.left, cell.bottom)))

        for checkpoint in checkpoint_cells:
            self.checkpoints.add_checkpoint((checkpoint.px, checkpoint.py))

        self.goku = Goku((player_cell.px, player_cell.py))
        #  self.vegeta = Vegeta((400, 200))
        self.players.add(self.goku)
        self.hud.add(Hud(self.goku))
        self.dragonball.add(Dragonball((finish.px, finish.py)))
        # self.players.add(self.vegeta)
        self.fpsClock = pygame.time.Clock()
        self.tilemap.layers.append(self.players)
        self.tilemap.layers.append(self.enemies)
        self.tilemap.layers.append(self.blasts)
        self.tilemap.layers.append(self.hud)
        self.tilemap.layers.append(self.dragonball)
Example no. 21
    def load_objects(self):
        for x, y in place_objects(CHECK):
            self.game_objects["check"].append(
                Checkpoint(x * 8, y * 8, self.plr))

        for x, y in place_objects(SWITCH):
            self.game_objects["switch"].append(Switch(x * 8, y * 8))

        for x, y in place_objects(BAD_ROBOT):
            self.game_objects["robot"].append(BadRobot(x * 8, y * 8, self.cam))

        for x, y in place_objects(BADDER_ROBOT):
            self.game_objects["robot"].append(
                BadRobot(x * 8, y * 8, self.cam, True))

        for x, y in place_objects(MOVING_PLATFORM):
            self.game_objects["platform"].append(MobilePlatform(x * 8, y * 8))

        for x, y in place_objects(MOVING_PLATFORM_OPPOSITE):
            self.game_objects["platform"].append(
                MobilePlatform(x * 8, y * 8, True))

        for x, y in place_objects(LASER):
            self.game_objects["laser"].append(
                Laser(x * 8, y * 8, self, self.cam, 3))

        for x, y in place_objects(FAST_LASER):
            self.game_objects["laser"].append(
                Laser(x * 8, y * 8, self, self.cam, 2))

        for x, y in place_objects(VERY_FAST_LASER):
            self.game_objects["laser"].append(
                Laser(x * 8, y * 8, self, self.cam, 1))

        for id in range(1, GATE_IDS + 1):
            for x, y in place_objects(GATE_START_ADDRESS + id):
                self.game_objects["gate"].append(Gate(x * 8, y * 8, id))
Example no. 22
 def from_dict(_dict):
     gid = _dict["gid"]
     gname = _dict["gname"]
     status = _dict.get("status", None)
     hostID = _dict["hostID"]
     hostName = _dict["hostName"]
     hostAvatar = _dict["hostAvatar"]
     checkpoints = [Checkpoint.from_dict(cp) for cp in _dict["checkpoints"]]
     players = [
         Player.from_dict(player) for player in _dict.get("players", [])
     ]
     min_players, max_players = _dict.get("min_players",
                                          6), _dict.get("max_players", 20)
     teams = _dict.get("teams", None)
     startTime = _dict.get("startTime", None)
     endTime = _dict.get("endTime", None)
     capturedCount = _dict.get("capturedCount", None)
     unCapturedCount = _dict.get("unCapturedCount", None)
     winTeam = _dict.get("winTeam", None)
     statsCount = dict.get(_dict, "statsCount", 0)
     return Game(gid, gname, status, hostID, hostName, hostAvatar,
                 checkpoints, players, teams, min_players, max_players,
                 startTime, endTime, capturedCount, unCapturedCount,
                 winTeam, statsCount)
Example no. 23
class Processor(object):
    """Class for processing dump files from postgresql."""
    MILLION = 1024 * 1024

    def __init__(self):
        self.bytes_count = 0
        self.start_time = 0.0
        self.out_files = {}
        self.checkpoint = Checkpoint(config.VALUE_SET)
        self.init_time()

    def init_time(self):
        """Init time."""

        self.start_time = time.time()

    def add_bytes_count(self, count: int):
        """Add up bytes count."""

        self.bytes_count += count

    def split_if_necessary(self) -> None:
        """Check size of each storage file, called each batch
        close and open a new one to store if size exceeds max_split_size
        """

        # Convert MB to Byte
        for v in config.VALUE_SET:
            file_size = self.out_files[v].tell()
            if file_size >= config.FILE_SPLIT_SIZE:
                self.checkpoint.update_file_index(v)
                new_file = open(
                    self.checkpoint.get_file_name(v, config.OUT_DIR), 'a')
                self.add_table_head(new_file)
                self.out_files[v].close()
                self.out_files[v] = new_file
                logging.info('File size grows over {:.2f} MB, '
                             'store in new file `{}`...'.format(
                                 config.FILE_SPLIT_SIZE / self.MILLION,
                                 new_file.name))

    def process_line(self, line: str) -> None:
        """Process each line, does NOT verify the validness of
         lines (print them and ignores invalid ones without terminating)
         check if this line is recorded, and record the line.

        :param line: str, line to process ('\n' not included)
        """

        attributes = line.split('\t')
        try:
            # Check value in values to group by
            value = attributes[config.GROUP_BY_ATTR_INDEX]
            if value not in config.VALUE_SET:
                return
            row_count = int(attributes[config.INDEX_ROW_COUNT])
            # Check if line is already parsed and recorded
            if row_count <= self.checkpoint.row_count[value]:
                return
            # Keep attributes we're interested in
            data = [attributes[i] for i in config.RECORD_ATTR_INDEX_LIST]
            # Write to related file
            self.out_files[value].write('\t'.join(data))
            self.out_files[value].write('\n')
            # Update index
            self.checkpoint.row_count[value] = row_count
        except Exception as e:
            logging.warning(e)
            logging.warning("Invalid row: {}".format(attributes))

    @staticmethod
    def verify_file_schema(fp: TextIO) -> bool:
        """Verify the schema of data contained in a file.
        The dump files of postgresql should contain exactly one table each.
        """

        line = fp.readline()
        # Remember to return to the head of the file
        fp.seek(0)
        if isinstance(line, bytes):
            line = str(line, encoding='utf-8')
        # Remove empty cells
        attributes = list(filter(None, line.split('\t')))
        # Check attribute count
        if len(attributes) != config.ATTR_COUNT:
            return False
        # Check validity of the index attribute
        try:
            _ = int(attributes[config.INDEX_ROW_COUNT])
        except ValueError:
            return False
        return True

    @staticmethod
    def add_table_head(f: TextIO) -> None:
        """Add headings of table."""

        f.write('\t'.join(config.RECORD_ATTR_LIST))
        f.write('\n')

    def process_file(self, filename: str, is_old_file: bool = False) -> None:
        """Process a text file (ends with '.dat') or gzip file (ends with .gz).

        :param filename: str, name of file to process
        :param is_old_file: bool, whether this file has been processed before
                if it has been, we should skip batches already read.
        :return: int, 0 if this file is ignored or 1 if processed
        """

        # Check file type
        file_type = filename[filename.rfind('.'):]
        if file_type not in config.OPEN_FUNCS:
            logging.info('Fail to process `{}`: unsupported file type.'.format(
                filename))
            return
        # Open file according to its type
        fp = config.OPEN_FUNCS[file_type](filename)

        # Old file: needs to recover to the starting point
        if is_old_file and self.checkpoint.offset > 0:
            fp.seek(self.checkpoint.offset)
            logging.info('Time for seeking file offset: {:.2f} s'.format(
                time.time() - self.start_time))
            # This should be the start of processing
            self.init_time()
        else:
            # New files:
            # needs to verify whether this file contains the table we want
            if not self.verify_file_schema(fp):
                logging.info(
                    'Schema of `{}` doesn\'t fit; skip.'.format(filename))
                fp.close()
                return
            # Record current file
            self.checkpoint.current_file = filename

        logging.info('Start processing `{}`...'.format(filename))
        while True:
            self.checkpoint.offset = fp.tell()
            batch = fp.read(config.BATCH_SIZE)
            # Finish the partially read last line so we only split complete lines
            line = fp.readline()
            if line:
                batch += line
            if not batch:
                break  # EOF
            # Convert from bytes to str if needed
            if isinstance(batch, bytes):
                batch = str(batch, 'utf-8')
            # Parse batch
            for line in batch.splitlines():
                self.process_line(line)
            self.add_bytes_count(len(batch))
            # Split large files and change storage to new files
            if config.SPLIT:
                self.split_if_necessary()
        fp.close()

    def process_dir(self, dirname: str) -> None:
        """Recursively process files in given directory.

        :param dirname: str, directory of files to precess
        :return: number of files processed under this directory
        """

        file_list = sorted(os.listdir(dirname))
        for name in file_list:
            # Full name of file
            name = os.path.join(dirname, name)
            # Check if this file is already processed
            if name in self.checkpoint.processed_files:
                continue
            if os.path.isfile(name):
                self.process_file(name)
                self.checkpoint.processed_files.add(name)
            elif os.path.isdir(name) and config.RECURSIVE:
                self.process_dir(name)

    def before_process(self) -> None:
        """Create directory if needed, and load records."""
        if not os.path.isdir(config.OUT_DIR):
            os.mkdir(config.OUT_DIR)
        # Load checkpoints from file
        if os.path.exists(config.RECORD_FILE):
            self.checkpoint.load(config.RECORD_FILE)
            logging.info('Checkpoint loaded from `{}`.'.format(
                config.RECORD_FILE))
        # Open files to write
        for v in config.VALUE_SET:
            f = open(self.checkpoint.get_file_name(v, config.OUT_DIR), 'a')
            # If it's a new file, add headings
            if f.tell() == 0:
                self.add_table_head(f)
            self.out_files[v] = f

    def process(self, dir_list: list) -> None:
        """Process list of directories / files"""
        try:
            # Prepare for processing
            self.before_process()
            # Recover from file processed last time
            if os.path.exists(self.checkpoint.current_file):
                logging.info('Reloading `{}` from last checkpoints...'.format(
                    self.checkpoint.current_file))
                self.process_file(self.checkpoint.current_file,
                                  is_old_file=True)
            if len(dir_list) == 0:
                logging.error(
                    'Please specify at least one directory or file to process.'
                )
            # Process each directory / file
            for dir_name in dir_list:
                if os.path.isdir(dir_name):
                    self.process_dir(dir_name)
                elif os.path.isfile(dir_name):
                    self.process_file(dir_name)
                else:
                    logging.warning(
                        '`{}` is not a directory / file; skip.'.format(
                            dir_name))
        # Ctrl + C manually stopped
        except KeyboardInterrupt:
            self.after_process(is_interrupted=True)
        # Other unknown exceptions...
        except Exception as e:
            logging.warning(e)
            self.after_process(is_interrupted=True)
        else:
            self.after_process(is_interrupted=False)

    def after_process(self, is_interrupted: bool) -> None:
        """Deal with opened files, useless files and save records."""
        # Close files, and remove files with zero contents
        head_len = len('\t'.join(config.RECORD_ATTR_LIST)) + 1
        for file in self.out_files.values():
            file.close()
            # Not strictly compare size
            if os.path.getsize(file.name) <= head_len + 100:
                os.remove(file.name)
        # Handle interrupts
        if is_interrupted:
            self.checkpoint.save(config.RECORD_FILE)
            logging.info('Checkpoint saved in `{}`.'.format(
                config.RECORD_FILE))
        # Normal ending, remove record file
        elif os.path.exists(config.RECORD_FILE):
            os.remove(config.RECORD_FILE)
        # Analyse speed
        total_mb = self.bytes_count / self.MILLION
        total_time = time.time() - self.start_time
        avg_speed = total_mb / total_time
        logging.info(
            'Processed {:.2f} MB in {:.2f} s, {:.2f} MB/s on average.'.format(
                total_mb, total_time, avg_speed))
        exit(int(is_interrupted))
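A hedged, standalone distillation of the resume-by-offset idea used by process_file above (file names below are hypothetical; this is not the Processor class itself):

import json
import os

def process_with_offset_checkpoint(path, record_file="offset.json", lines_per_save=1000):
    """Persist fp.tell() periodically so an interrupted pass over a large
    file can seek() back and resume where it stopped."""
    offset = 0
    if os.path.exists(record_file):
        with open(record_file) as rf:
            offset = json.load(rf)["offset"]
    with open(path, "r") as fp:
        fp.seek(offset)
        processed = 0
        while True:
            line = fp.readline()
            if not line:
                break
            # process the line here (e.g. split on '\t' and filter as above)
            processed += 1
            if processed % lines_per_save == 0:
                with open(record_file, "w") as rf:
                    json.dump({"offset": fp.tell()}, rf)
    if os.path.exists(record_file):
        os.remove(record_file)  # normal completion: clear the record, as after_process does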
Example no. 24
 def __init__(self):
     self.bytes_count = 0
     self.start_time = 0.0
     self.out_files = {}
     self.checkpoint = Checkpoint(config.VALUE_SET)
     self.init_time()
Example no. 25
	def testCheckpointWrite(self):
		x = Checkpoint(TestCheckpoint.TEST_DIR)
		x.createCheckpointLog(TestCheckpoint.TEST_KEY)

		x.writeCheckpoint(TestCheckpoint.TEST_KEY, "a", 1)
		x.writeCheckpoint(TestCheckpoint.TEST_KEY, "a", 2)
		x.writeCheckpoint(TestCheckpoint.TEST_KEY, "b", 3)

		x.releaseCheckpointLog(TestCheckpoint.TEST_KEY)

		self.assertEqual(sorted(x.getCheckpointKeys()), ["a", "b"])
		self.assertEqual(x.getCheckpoints("b"), [3])
		self.assertEqual(x.getCheckpoints("a"), [1, 2])
Example no. 26
	def __init__(self, config):
		self.acc = INIT_ACC
		Checkpoint.__init__(self, config)
		self.buff = 0
Example no. 27
    def train_epochs(self, encoder, decoder, start_epoch, start_step, train_data,
                     dev_data, end_epoch, log_file):

        #Prepare constants
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        params = {'batch_size': self.batch_size,
                  'shuffle': True,
                  'num_workers': 4,
                  'drop_last': True}


        #Prepare dataloader and remaining constants.
        training_data   = DataLoader(train_data, **params)
        val_data        = DataLoader(dev_data, **params)
        steps_per_epoch = len(training_data)
        step            = start_step
        tot_steps       = steps_per_epoch*(end_epoch - start_epoch)
        elapsed_steps   = 0

        for epoch in range(start_epoch,end_epoch):
            print("Epoch: {:d}  Step: {:d}".format(epoch,step),file=open(log_file, 'a'))
            start = time.time()
            elapsed_steps = 0
            epoch_loss_total = 0.0
            reconstruction_accuracy_total = 0.0
            loss_total = 0.0
            KL_div_total = 0.0

            for batch in training_data:
                batch = batch.to(device)
                loss, reconstruction_accuracy, KL_div = self.train_batch(step, encoder, decoder, batch, self.inverse_sigmoid(step))
                loss_total += loss
                epoch_loss_total += loss
                reconstruction_accuracy_total += reconstruction_accuracy
                KL_div_total += KL_div
                step += 1
                elapsed_steps += 1

                if step%self.print_every == 0:
                    if elapsed_steps > self.print_every:
                        cnt = self.print_every
                    else:
                        cnt = elapsed_steps
                    loss_avg = loss_total /cnt
                    reconstruction_accuracy_avg = reconstruction_accuracy_total/cnt
                    KL_div_avg = KL_div_total/cnt
                    loss_total = 0
                    reconstruction_accuracy_total = 0
                    KL_div_total = 0
                    print(("Progress: {:.2f}%"
                    " Average Loss: {:2.2f}"
                    " Reconstruction Accuracy: {:2.2f}%"
                    " KL Divergence: {:2.2f}"
                    " ").format((elapsed_steps / steps_per_epoch)*100, loss_avg,reconstruction_accuracy_avg,KL_div_avg),file=open(log_file, 'a'))
                if step%self.checkpoint_every == 0:
                    print("Trying to checkpoint.")
                    Checkpoint( encoder   = encoder,
                                decoder   = decoder,
                                epoch     = epoch,
                                step      = step,
                                optimizer = self.optimizer,
                                scheduler = self.scheduler,
                                samp_rate = self.samp_rate,
                                KL_rate   = self.KL_rate,
                                free_bits = self.free_bits
                                ).save(self.exp_dir)
                    print("Checkpoint successful!")


            print("End of epoch. Time elapsed: " + timer(start, time.time()), file=open(log_file, 'a'))
            print("Average loss for this epoch: {:2.2f} ".format(epoch_loss_total/elapsed_steps), file=open(log_file, 'a'))
            Checkpoint( encoder   = encoder,
                        decoder   = decoder,
                        epoch     = epoch+1,
                        step      = step,
                        optimizer = self.optimizer,
                        scheduler = self.scheduler,
                        samp_rate = self.samp_rate,
                        KL_rate   = self.KL_rate,
                        free_bits = self.free_bits
                        ).save(self.exp_dir)

            #Now, compute validation.
            with torch.no_grad():
                reconstruction_accuracy_val = 0.0
                reconstruction_accuracy_val_nf = 0.0
                val_loss  = 0.0
                val_KL_tot = 0.0
                val_loss_nf = 0.0
                val_KL_tot_nf = 0.0
                count = 0
                for val_batch in val_data:
                    count += 1
                    val_batch = val_batch.to(device)
                    batch_loss, batch_accuracy, val_KL = self.loss(step, encoder,decoder, val_batch, 1)
                    batch_loss_nf, batch_accuracy_nf, val_KL_nf = self.loss(step, encoder, decoder, val_batch, 0)
                    val_loss += batch_loss
                    reconstruction_accuracy_val += batch_accuracy
                    val_KL_tot += val_KL

                    val_loss_nf += batch_loss_nf
                    reconstruction_accuracy_val_nf += batch_accuracy_nf
                    val_KL_tot_nf += val_KL_nf

                reconstruction_accuracy_val /= count
                val_loss /= count
                val_KL_tot /= count

                reconstruction_accuracy_val_nf /= count
                val_loss_nf /= count
                val_KL_tot_nf /= count
                print("Validation results: ", file=open(log_file, 'a'))
                print("Reconstruction Accuracy: {:2.2f}%"
                " Loss (Validation): {:2.2f}"
                " KL Divergence {:2.2f}".format(100*reconstruction_accuracy_val,val_loss,val_KL_tot), file=open(log_file, 'a'))

                print("Reconstruction Accuracy (NF): {:2.2f}%"
                " Loss (NF): {:2.2f}"
                " KL Divergence (NF) {:2.2f}".format(100*reconstruction_accuracy_val_nf,val_loss_nf,val_KL_tot_nf), file=open(log_file, 'a'))
Example no. 28
    def _train_gan(self):
        """
        TODO: Add in autoencoder to perform dimensionality reduction on data
        TODO: Not working yet - trying to work out good autoencoder model first
        :return:
        """

        criterion = nn.BCELoss()

        discriminator_optimiser = optim.Adam(self.discriminator.parameters(),
                                             lr=0.003,
                                             betas=(0.5, 0.999))
        discriminator_scheduler = optim.lr_scheduler.LambdaLR(
            discriminator_optimiser, lambda epoch: 0.97**epoch)
        discriminator_checkpoint = Checkpoint("discriminator")
        discriminator_epoch = 0
        if discriminator_checkpoint.load():
            discriminator_epoch = self.load_state(discriminator_checkpoint,
                                                  self.discriminator,
                                                  discriminator_optimiser)
        else:
            LOG.info('Discriminator checkpoint not found')

        generator_optimiser = optim.Adam(self.generator.parameters(),
                                         lr=0.003,
                                         betas=(0.5, 0.999))
        generator_scheduler = optim.lr_scheduler.LambdaLR(
            generator_optimiser, lambda epoch: 0.97**epoch)
        generator_checkpoint = Checkpoint("generator")
        generator_epoch = 0
        if generator_checkpoint.load():
            generator_epoch = self.load_state(generator_checkpoint,
                                              self.generator,
                                              generator_optimiser)
        else:
            LOG.info('Generator checkpoint not found')

        if discriminator_epoch is None or generator_epoch is None:
            epoch = 0
            LOG.info(
                "Discriminator or generator failed to load, training from start"
            )
        else:
            epoch = min(generator_epoch, discriminator_epoch)
            LOG.info("Generator loaded at epoch {0}".format(generator_epoch))
            LOG.info("Discriminator loaded at epoch {0}".format(
                discriminator_epoch))
            LOG.info("Training from lowest epoch {0}".format(epoch))

        vis_path = os.path.join(
            os.path.splitext(self.config.FILENAME)[0], "gan",
            str(datetime.now()))
        with Visualiser(vis_path) as vis:
            real_labels = None  # all 1s
            fake_labels = None  # all 0s
            epochs_complete = 0
            while epoch < self.config.MAX_EPOCHS:

                if self.check_requeue(epochs_complete):
                    return  # Requeue needed and training not complete

                for step, (data, noise1,
                           noise2) in enumerate(self.data_loader):
                    batch_size = data.size(0)
                    if real_labels is None or real_labels.size(0) != batch_size:
                        real_labels = self.generate_labels(batch_size, [1.0])
                    if fake_labels is None or fake_labels.size(0) != batch_size:
                        fake_labels = self.generate_labels(batch_size, [0.0])

                    if self.config.USE_CUDA:
                        data = data.cuda()
                        noise1 = noise1.cuda()
                        noise2 = noise2.cuda()

                    # ============= Train the discriminator =============
                    # Pass real noise through first - ideally the discriminator will return 1 #[1, 0]
                    d_output_real = self.discriminator(data)
                    # Pass generated noise through - ideally the discriminator will return 0 #[0, 1]
                    d_output_fake1 = self.discriminator(self.generator(noise1))

                    # Determine the loss of the discriminator by adding up the real and fake loss and backpropagate
                    d_loss_real = criterion(
                        d_output_real, real_labels
                    )  # How good the discriminator is on real input
                    d_loss_fake = criterion(
                        d_output_fake1, fake_labels
                    )  # How good the discriminator is on fake input
                    d_loss = d_loss_real + d_loss_fake
                    self.discriminator.zero_grad()
                    d_loss.backward()
                    discriminator_optimiser.step()

                    # =============== Train the generator ===============
                    # Pass in fake noise to the generator and get it to generate "real" noise
                    # Judge how good this noise is with the discriminator
                    d_output_fake2 = self.discriminator(self.generator(noise2))

                    # Determine the loss of the generator using the discriminator and backpropagate
                    g_loss = criterion(d_output_fake2, real_labels)
                    self.discriminator.zero_grad()
                    self.generator.zero_grad()
                    g_loss.backward()
                    generator_optimiser.step()

                    vis.step(d_loss_real.item(), d_loss_fake.item(),
                             g_loss.item())

                    # Report data and save checkpoint
                    fmt = "Epoch [{0}/{1}], Step[{2}/{3}], d_loss_real: {4:.4f}, d_loss_fake: {5:.4f}, g_loss: {6:.4f}"
                    LOG.info(
                        fmt.format(epoch + 1, self.config.MAX_EPOCHS, step + 1,
                                   len(self.data_loader), d_loss_real,
                                   d_loss_fake, g_loss))

                epoch += 1
                epochs_complete += 1

                discriminator_checkpoint.set(
                    self.discriminator.state_dict(),
                    discriminator_optimiser.state_dict(), epoch).save()
                generator_checkpoint.set(self.generator.state_dict(),
                                         generator_optimiser.state_dict(),
                                         epoch).save()
                vis.plot_training(epoch)

                data, noise1, _ = iter(self.data_loader).__next__()
                if self.config.USE_CUDA:
                    data = data.cuda()
                    noise1 = noise1.cuda()
                vis.test(epoch, self.data_loader.get_input_size_first(),
                         self.discriminator, self.generator, noise1, data)

                generator_scheduler.step(epoch)
                discriminator_scheduler.step(epoch)

                LOG.info("Learning rates: d {0} g {1}".format(
                    discriminator_optimiser.param_groups[0]["lr"],
                    generator_optimiser.param_groups[0]["lr"]))

        LOG.info("GAN Training complete")
Example no. 29
    def _train_autoencoder(self):
        """
        Main training loop for the autoencoder.
        This function will return False if:
        - Loading the autoencoder succeeded, but the NN model did not load the state dicts correctly.
        - The script needs to be re-queued because the NN has been trained for REQUEUE_EPOCHS.
        :return: True if training was completed, False if training needs to continue.
        :rtype: bool
        """

        criterion = nn.SmoothL1Loss()

        optimiser = optim.Adam(self.generator.parameters(),
                               lr=0.00003,
                               betas=(0.5, 0.999))
        checkpoint = Checkpoint("autoencoder")
        epoch = 0
        if checkpoint.load():
            epoch = self.load_state(checkpoint, self.autoencoder, optimiser)
            if epoch is not None and epoch >= self.config.MAX_AUTOENCODER_EPOCHS:
                LOG.info("Autoencoder already trained")
                return True
            else:
                LOG.info(
                    "Autoencoder training beginning from epoch {0}".format(
                        epoch))
        else:
            LOG.info('Autoencoder checkpoint not found. Training from start')

        # Train autoencoder
        self._autoencoder.set_mode(Autoencoder.Mode.AUTOENCODER)

        vis_path = os.path.join(
            os.path.splitext(self.config.FILENAME)[0], "autoencoder",
            str(datetime.now()))
        with Visualiser(vis_path) as vis:
            epochs_complete = 0
            while epoch < self.config.MAX_AUTOENCODER_EPOCHS:

                if self.check_requeue(epochs_complete):
                    return False  # Requeue needed and training not complete

                for step, (data, _, _) in enumerate(self.data_loader):
                    if self.config.USE_CUDA:
                        data = data.cuda()

                    if self.config.ADD_DROPOUT:
                        # Drop out parts of the input, but compute loss on the full input.
                        out = self.autoencoder(nn.functional.dropout(
                            data, 0.5))
                    else:
                        out = self.autoencoder(data)

                    loss = criterion(out.cpu(), data.cpu())
                    self.autoencoder.zero_grad()
                    loss.backward()
                    optimiser.step()

                    vis.step_autoencoder(loss.item())

                    # Report data and save checkpoint
                    fmt = "Epoch [{0}/{1}], Step[{2}/{3}], loss: {4:.4f}"
                    LOG.info(
                        fmt.format(epoch + 1,
                                   self.config.MAX_AUTOENCODER_EPOCHS, step,
                                   len(self.data_loader), loss))

                epoch += 1
                epochs_complete += 1

                checkpoint.set(self.autoencoder.state_dict(),
                               optimiser.state_dict(), epoch).save()

                LOG.info("Plotting autoencoder progress")
                vis.plot_training(epoch)
                data, _, _ = iter(self.data_loader).__next__()
                vis.test_autoencoder(epoch, self.autoencoder, data.cuda())

        LOG.info("Autoencoder training complete")
        return True  # Training complete
Example no. 30
def main(env_name, num_episodes, gamma, lam, kl_targ, batch_size, nprocs,
         policy_hid_list, valfunc_hid_list, gpu_pct, restore_path, animate,
         submit):
    """ Main training loop

    Args:
        env_name: OpenAI Gym environment name, e.g. 'Hopper-v1'
        num_episodes: maximum number of episodes to run
        gamma: reward discount factor (float)
        lam: lambda from Generalized Advantage Estimate
        kl_targ: D_KL target for policy update [D_KL(pi_old || pi_new)]
        batch_size: number of episodes per policy training batch
    """
    # killer = GracefulKiller()

    env, obs_dim, act_dim = init_osim(animate)
    env.seed(111 + mpi_util.rank)
    mpi_util.set_global_seeds(111 + mpi_util.rank)

    obs_dim += 1  # add 1 to obs dimension for time step feature (see run_episode())
    now = datetime.utcnow().strftime(
        "%b-%d_%H:%M:%S")  # create unique directories
    if mpi_util.rank == 0:
        #aigym_path = os.path.join('/tmp', env_name, now)
        #env = wrappers.Monitor(env, aigym_path, force=True)
        logger = Logger(logname=env_name, now=now)

    episode = 0

    checkpoint = Checkpoint("saves", now)
    # restore from checkpoint?
    if restore_path:
        (policy, val_func, scaler, episode, obs_dim, act_dim,
         kl_targ) = checkpoint.restore(restore_path)
    else:
        policy = Policy(obs_dim, act_dim, kl_targ)
        val_func = NNValueFunction(obs_dim)
        scaler = Scaler(obs_dim)

        if mpi_util.rank == 0:
            # run a few episodes (on node 0) of untrained policy to initialize scaler:
            trajectories = run_policy(env, policy, scaler, episodes=5)

            unscaled = np.concatenate(
                [t['unscaled_obs'] for t in trajectories])
            scaler.update(
                unscaled)  # update running statistics for scaling observations

        # broadcast policy weights, scaler, val_func
        (policy, scaler, val_func) = mpi_util.broadcast_policy_scaler_val(
            policy, scaler, val_func)

        if mpi_util.rank == 0:
            checkpoint.save(policy, val_func, scaler, episode)

    if animate:
        observes, actions, rewards, unscaled_obs = run_episode(env,
                                                               policy,
                                                               scaler,
                                                               animate=animate)
        exit(0)

    if submit:
        # Settings
        #remote_base = 'http://grader.crowdai.org:1729'
        remote_base = 'http://grader.crowdai.org:1730'
        token = 'a83412a94593cae3a491f3ee28ff44e1'

        client = Client(remote_base)

        # Create environment
        observation = client.env_create(token)
        step = 0.0
        observes, actions, rewards, unscaled_obs = [], [], [], []
        scale, offset = scaler.get()
        scale[-1] = 1.0  # don't scale time step feature
        offset[-1] = 0.0  # don't offset time step feature

        # Run a single step
        #
        # The grader runs 3 simulations of at most 1000 steps each. We stop after the last one
        while True:
            obs = np.array(observation).astype(np.float32).reshape((1, -1))
            print("OBSERVATION TYPE:", type(obs), obs.shape)
            print(obs)
            obs = np.append(obs, [[step]], axis=1)  # add time step feature
            unscaled_obs.append(obs)
            obs = (obs - offset) * scale  # center and scale observations
            observes.append(obs)

            action = policy.sample(obs).astype(np.float32).reshape((-1, 1))
            print("ACTION TYPE:", type(action), action.shape)
            print(action)
            actions.append(action)

            [observation, reward, done,
             info] = client.env_step(action.tolist())
            print("step:", step, "reward:", reward)

            if not isinstance(reward, float):
                reward = float(reward)  # np.asscalar is deprecated and removed in newer NumPy
            rewards.append(reward)
            step += 1e-3  # increment time step feature

            if done:
                print(
                    "================================== RESTARTING ================================="
                )
                observation = client.env_reset()
                step = 0.0
                observes, actions, rewards, unscaled_obs = [], [], [], []
                scale, offset = scaler.get()
                scale[-1] = 1.0  # don't scale time step feature
                offset[-1] = 0.0  # don't offset time step feature
                if not observation:
                    break

        client.submit()
        exit(0)

    ######

    worker_batch_size = int(batch_size / mpi_util.nworkers)  # HACK
    if (worker_batch_size * mpi_util.nworkers != batch_size):
        print("batch_size:", batch_size, " is not divisible by nworkers:",
              mpi_util.nworkers)
        exit(1)

    batch = 0
    while episode < num_episodes:
        if mpi_util.rank == 0 and batch > 0 and batch % 10 == 0:
            checkpoint.save(policy, val_func, scaler, episode)
        batch = batch + 1

        trajectories = run_policy(env,
                                  policy,
                                  scaler,
                                  episodes=worker_batch_size)
        trajectories = mpi_util.gather_trajectories(trajectories)

        if mpi_util.rank == 0:
            # concatenate trajectories into one list
            trajectories = list(itertools.chain.from_iterable(trajectories))
            print("did a batch of ", len(trajectories), " trajectories")
            print([t['rewards'].sum() for t in trajectories])

            episode += len(trajectories)
            add_value(trajectories,
                      val_func)  # add estimated values to episodes
            add_disc_sum_rew(trajectories,
                             gamma)  # calculated discounted sum of Rs
            add_gae(trajectories, gamma, lam)  # calculate advantage

            # concatenate all episodes into single NumPy arrays
            observes, actions, advantages, disc_sum_rew = build_train_set(
                trajectories)

            # add various stats to training log:
            logger.log({
                '_MeanReward':
                np.mean([t['rewards'].sum() for t in trajectories]),
                'Steps':
                np.sum([t['observes'].shape[0] for t in trajectories])
            })
            log_batch_stats(observes, actions, advantages, disc_sum_rew,
                            logger, episode)

            policy.update(observes, actions, advantages,
                          logger)  # update policy
            val_func.fit(observes, disc_sum_rew,
                         logger)  # update value function

            unscaled = np.concatenate(
                [t['unscaled_obs'] for t in trajectories])
            scaler.update(
                unscaled)  # update running statistics for scaling observations

            logger.write(
                display=True)  # write logger results to file and stdout

        # if mpi_util.rank == 0 and killer.kill_now:
        #     if input('Terminate training (y/[n])? ') == 'y':
        #         break
        #     killer.kill_now = False

        # broadcast policy weights, scaler, val_func
        (policy, scaler, val_func) = mpi_util.broadcast_policy_scaler_val(
            policy, scaler, val_func)

    if mpi_util.rank == 0: logger.close()
    policy.close_sess()
    if mpi_util.rank == 0: val_func.close_sess()
Esempio n. 31
0
	def setUp(self):
		os.mkdir(TestActionsStateMachine.TEST_DIR)
		self.cp = Checkpoint(TestActionsStateMachine.TEST_DIR)
		self.cp.createCheckpointLog(TestActionsStateMachine.TEST_KEY)
Esempio n. 32
0
class TestActionsStateMachine(unittest.TestCase):
	TEST_DIR = "/tmp/state_test"
	TEST_KEY = "test_key"

	def setUp(self):
		os.mkdir(TestActionsStateMachine.TEST_DIR)
		self.cp = Checkpoint(TestActionsStateMachine.TEST_DIR)
		self.cp.createCheckpointLog(TestActionsStateMachine.TEST_KEY)

	def tearDown(self):
		os.system("rm -rf " + TestActionsStateMachine.TEST_DIR)

	def testBasic(self):
		actions = flickr_uploader.get_actions("photo", "set", self.cp)[0]

		self.assertTrue(actions & flickr_uploader.UPLOAD_PHOTO)
		self.assertTrue(actions & flickr_uploader.CREATE_SET)
		self.assertFalse(actions & flickr_uploader.ADD_TO_SET)

	def testSetCreation(self):
		self.cp.writeCheckpoint(TestActionsStateMachine.TEST_KEY, "photo",
			{"status": flickr_uploader.PHOTO_UPLOADED, "photo_id": 3})

		res = flickr_uploader.get_actions("photo", "set", self.cp)
		actions = res[0]
		data = res[1]

		self.assertTrue(actions & flickr_uploader.CREATE_SET)
		self.assertFalse(actions & flickr_uploader.UPLOAD_PHOTO)
		self.assertFalse(actions & flickr_uploader.ADD_TO_SET)

		self.assertEqual(3, data["photo_id"])

	def testUploadAndAdd(self):
		self.cp.writeCheckpoint(TestActionsStateMachine.TEST_KEY, "photo",
			{"status": flickr_uploader.PHOTO_UPLOADED, "photo_id": 3})

		self.cp.writeCheckpoint(TestActionsStateMachine.TEST_KEY, "set",
			{"status": flickr_uploader.SET_CREATED, "set_id": 3})

		self.cp.writeCheckpoint(TestActionsStateMachine.TEST_KEY, "photo",
			{"status": flickr_uploader.ADDED_TO_SET})

		res = flickr_uploader.get_actions("photo2", "set", self.cp)
		actions = res[0]
		data = res[1]

		self.assertFalse(actions & flickr_uploader.CREATE_SET)
		self.assertTrue(actions & flickr_uploader.UPLOAD_PHOTO)
		self.assertTrue(actions & flickr_uploader.ADD_TO_SET)

		self.assertEqual(3, data["set_id"])

	def testAddToSet(self):
		self.cp.writeCheckpoint(TestActionsStateMachine.TEST_KEY, "photo",
			{"status": flickr_uploader.PHOTO_UPLOADED, "photo_id": 3})

		self.cp.writeCheckpoint(TestActionsStateMachine.TEST_KEY, "set",
			{"status": flickr_uploader.SET_CREATED, "set_id": 3})

		self.cp.writeCheckpoint(TestActionsStateMachine.TEST_KEY, "photo",
			{"status": flickr_uploader.ADDED_TO_SET})

		self.cp.writeCheckpoint(TestActionsStateMachine.TEST_KEY, "photo2",
			{"status": flickr_uploader.PHOTO_UPLOADED, "photo_id": 4})
		
		res = flickr_uploader.get_actions("photo2", "set", self.cp)
		actions = res[0]
		data = res[1]

		self.assertFalse(actions & flickr_uploader.CREATE_SET)
		self.assertFalse(actions & flickr_uploader.UPLOAD_PHOTO)
		self.assertTrue(actions & flickr_uploader.ADD_TO_SET)

		self.assertEqual(3, data["set_id"])
		self.assertEqual(4, data["photo_id"])
Esempio n. 33
0
#                   input_dir=args.input_dir,
#                   labeled_only=True)

num_classes = 10
torch_device = torch.device('cuda')
# create new model with random parameters
model = Model(ndf=args.ndf,
              n_classes=num_classes,
              n_rkhs=args.n_rkhs,
              tclip=args.tclip,
              n_depth=args.n_depth,
              enc_size=enc_size,
              use_bn=(args.use_bn == 1))
model.init_weights(init_scale=1.0)
# restore model parameters from a checkpoint if requested
checkpoint = Checkpoint(model, args.cpt_load_path, args.output_dir,
                        args.cpt_name)
model = model.to(torch_device)

# select which type of training to do
task = train_classifiers if args.classifiers else train_self_supervised

ckpt = torch.load('/root/amdim-public/runs_stl64_norm_BN/cifar_amdim_cpt.pth')
params = ckpt['model']
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in params.items():
    name = k.replace("module.", "")
    new_state_dict[name] = v
# print(new_state_dict)
model.load_state_dict(new_state_dict)
# model.load_state_dict(params)
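
The renaming loop above is the usual workaround for checkpoints saved from an nn.DataParallel wrapper, whose state-dict keys all start with "module.". A minimal, self-contained sketch of the same pattern with a toy model (illustrative only, not this repository's code):

from collections import OrderedDict
import torch.nn as nn

def strip_dataparallel_prefix(state_dict):
    # remove only the leading "module." so the keys match the bare (unwrapped) model
    return OrderedDict((k.replace("module.", "", 1), v)
                       for k, v in state_dict.items())

net = nn.Linear(4, 2)                      # toy model
ckpt = nn.DataParallel(net).state_dict()   # keys: 'module.weight', 'module.bias'
net.load_state_dict(strip_dataparallel_prefix(ckpt))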
Esempio n. 34
0
    def from_dict(_dict):
        gpid = _dict["gpid"]
        checkpoints = [Checkpoint.from_dict(checkpoint) for checkpoint in _dict.get("checkpoints", [])]
        players = [Player.from_dict(player) for player in _dict.get("players", [])]
        min_players, max_players = _dict.get("min_players", 8), _dict.get("max_players", 20)
        return GameParameters(gpid, checkpoints, players, min_players, max_players)
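
The deserializer above pairs naturally with a to_dict() inverse. A minimal, self-contained sketch of that round trip with a stand-in dataclass (illustrative names only; the real GameParameters, Checkpoint and Player classes are not shown here):

from dataclasses import dataclass, field
from typing import List

@dataclass
class GameParams:
    # stand-in type, purely illustrative
    gpid: str
    checkpoints: List[dict] = field(default_factory=list)
    players: List[dict] = field(default_factory=list)
    min_players: int = 8
    max_players: int = 20

    @staticmethod
    def from_dict(d):
        # same pattern as above: the required key raises, optional keys fall back to defaults
        return GameParams(d["gpid"],
                          d.get("checkpoints", []),
                          d.get("players", []),
                          d.get("min_players", 8),
                          d.get("max_players", 20))

    def to_dict(self):
        return {"gpid": self.gpid, "checkpoints": self.checkpoints,
                "players": self.players, "min_players": self.min_players,
                "max_players": self.max_players}

gp = GameParams.from_dict({"gpid": "g1"})
assert GameParams.from_dict(gp.to_dict()) == gp   # the round trip is lossless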
Esempio n. 35
0
def main():
    colorama.init()
    print("Thank you for using canvas_grab!")
    print(
        f"You are using version {VERSION}. If you have any questions, please file an issue at {Fore.BLUE}https://github.com/skyzh/canvas_grab/issues{Style.RESET_ALL}"
    )
    print(
        f"You may review {Fore.GREEN}README(_zh-hans).md{Style.RESET_ALL} and {Fore.GREEN}LICENSE{Style.RESET_ALL} shipped with this release"
    )
    config.load_config()
    if config.ENABLE_VIDEO:
        print(
            f"Note: You've enabled video download. You should install the required tools yourself."
        )
        print(
            f"      This is an experimental functionality and takes up large amount of bandwidth. {Fore.RED}Use at your own risk.{Style.RESET_ALL}"
        )
    canvas = Canvas(config.API_URL, config.API_KEY)

    try:
        print(f'{Fore.BLUE}Logging in...{Style.RESET_ALL}')
        print(
            f"{Fore.GREEN}Logged in to {config.API_URL} as {canvas.get_current_user()}{Style.RESET_ALL}"
        )
    except canvasapi.exceptions.InvalidAccessToken:
        print(
            f"{Fore.RED}Invalid access token, please check your config.API_KEY in config file"
        )
        if is_windows():
            # for windows double-click user
            input()
        exit()

    try:
        global checkpoint
        checkpoint = Checkpoint(config.CHECKPOINT_FILE)
        checkpoint.load()
    except FileNotFoundError:
        print(f"{Fore.RED}No checkpoint found{Style.RESET_ALL}")

    courses = [
        course for course in canvas.get_courses() if hasattr(course, "name")
    ]
    if config.WHITELIST_CANVAS_ID:
        print(f"{Fore.BLUE}Whilelist mode enabled{Style.RESET_ALL}")
        courses = [
            course for course in courses
            if course.id in config.WHITELIST_CANVAS_ID
        ]
    try:
        for course in courses:
            if course.start_at:
                delta = -(datetime.strptime(
                    course.start_at, r'%Y-%m-%dT%H:%M:%S%z').replace(
                        tzinfo=None) - datetime.now()).days
            else:
                delta = 0
            if course.id in config.IGNORED_CANVAS_ID:
                print(
                    f"{Fore.CYAN}Explicitly Ignored Course: {course.course_code}{Style.RESET_ALL}"
                )
            elif config.RETAIN_COURSE_DAYS != 0 and delta > config.RETAIN_COURSE_DAYS:
                print(
                    f"{Fore.CYAN}Outdated Course: {course.course_code}{Style.RESET_ALL}"
                )
            else:
                try:
                    process_course(course)
                except KeyboardInterrupt:
                    raise
                except canvasapi.exceptions.Unauthorized as e:
                    print(
                        f"{Fore.RED}An error occoured when processing this course (unauthorized): {e}{Style.RESET_ALL}"
                    )
                except canvasapi.exceptions.ResourceDoesNotExist as e:
                    print(
                        f"{Fore.RED}An error occoured when processing this course (resourse not exist): {e}{Style.RESET_ALL}"
                    )
        if config.SCAN_STALE_FILE:
            scan_stale_files(courses)
    except KeyboardInterrupt:
        print(
            f"{Fore.RED}Terminated due to keyboard interrupt.{Style.RESET_ALL}"
        )

    checkpoint.dump()

    if new_files_list:
        print(
            f"{Fore.GREEN}{len(new_files_list)} new or updated files:{Style.RESET_ALL}"
        )
        for f in new_files_list:
            print(f"    {f}")

    if updated_files_list:
        print(
            f"{Fore.GREEN}{len(updated_files_list)} files have a more recent version on Canvas:{Style.RESET_ALL}"
        )
        for f in updated_files_list:
            print(f"    {f}")

    if failure_file_list:
        print(
            f"{Fore.YELLOW}{len(failure_file_list)} files are not downloaded:{Style.RESET_ALL}"
        )
        for f in failure_file_list:
            print(f"    {f}")

    if not new_files_list and not updated_files_list:
        print("All files up to date")

    if config.ENABLE_VIDEO:
        print(
            f"{Fore.GREEN}{len(ffmpeg_commands)} videos resolved{Style.RESET_ALL}"
        )
        print(
            f"Please run the automatically-generated script {Fore.BLUE}download_video.(sh/ps1){Style.RESET_ALL} to download all videos."
        )
        with open("download_video.sh", 'w') as file:
            file.write("\n".join(ffmpeg_commands))
        with open("download_video.ps1", 'w') as file:
            file.write("\n".join(ffmpeg_commands))

    if config.ALLOW_VERSION_CHECK:
        check_latest_version()

    print(f"{Fore.GREEN}Done.{Style.RESET_ALL}")

    if is_windows():
        # for windows double-click user
        input()
Esempio n. 36
0
def training(edit_net,
             nepochs,
             args,
             vocab,
             print_every=100,
             check_every=500,
             test=False):
    if test:
        print(args.data_path + 'test.df.filtered.pos')
        eval_dataset = data.Dataset(
            args.data_path + 'test.df.filtered.pos')  # load test dataset
    else:
        print(args.data_path + 'val.df.filtered.pos')
        eval_dataset = data.Dataset(
            args.data_path + 'val.df.filtered.pos')  # load val dataset
    evaluator = Evaluator(
        loss=nn.NLLLoss(ignore_index=vocab.w2i['PAD'], reduction='none'))
    editnet_optimizer = torch.optim.Adam(edit_net.parameters(),
                                         lr=1e-3,
                                         weight_decay=1e-6)
    # scheduler = MultiStepLR(abstract_optimizer, milestones=[20,30,40], gamma=0.1)
    # abstract_scheduler = ReduceLROnPlateau(abstract_optimizer, mode='max')

    # uncomment this part to re-weight different operations
    # NLL_weight = reweight_global_loss(args.w_add, args.w_keep, args.w_del)
    # NLL_weight_t = torch.from_numpy(NLL_weight).float().cuda()
    # editnet_criterion = nn.NLLLoss(weight=NLL_weight_t, ignore_index=vocab.w2i['PAD'], reduce=False)
    editnet_criterion = nn.NLLLoss(ignore_index=vocab.w2i['PAD'],
                                   reduction='none')

    best_eval_loss = float('inf')  # init statistics; lower dev loss is better
    print_loss = []  # Reset every print_every

    for epoch in range(nepochs):
        # scheduler.step()
        #reload training for every epoch
        if os.path.isfile(args.data_path + 'train.df.filtered.pos'):
            train_dataset = data.Dataset(args.data_path +
                                         'train.df.filtered.pos')
        else:  # iter chunks and vocab_data
            train_dataset = data.Datachunk(args.data_path +
                                           'train.df.filtered.pos')

        for i, batch_df in train_dataset.batch_generator(
                batch_size=args.batch_size, shuffle=True):

            #     time1 = time.time()
            prepared_batch, syn_tokens_list = data.prepare_batch(
                batch_df, vocab, args.max_seq_len)  #comp,scpn,simp

            # a batch of complex tokens in vocab ids, sorted in descending order
            org_ids = prepared_batch[0]
            org_lens = org_ids.ne(0).sum(1)
            org = sort_by_lens(
                org_ids, org_lens
            )  # inp=[inp_sorted, inp_lengths_sorted, inp_sort_order]
            # a batch of pos-tags in pos-tag ids for complex
            org_pos_ids = prepared_batch[1]
            org_pos_lens = org_pos_ids.ne(0).sum(1)
            org_pos = sort_by_lens(org_pos_ids, org_pos_lens)

            out = prepared_batch[2][:, :]
            tar = prepared_batch[2][:, 1:]

            simp_ids = prepared_batch[3]

            editnet_optimizer.zero_grad()
            output = edit_net(org, out, org_ids, org_pos, simp_ids)
            ##################calculate loss
            tar_lens = tar.ne(0).sum(1).float()
            tar_flat = tar.contiguous().view(-1)
            loss = editnet_criterion(output.contiguous().view(-1, vocab.count),
                                     tar_flat).contiguous()
            loss[tar_flat == 1] = 0  #remove loss for UNK
            loss = loss.view(tar.size())
            loss = loss.sum(1).float()
            loss = loss / tar_lens
            loss = loss.mean()

            print_loss.append(loss.item())
            loss.backward()

            torch.nn.utils.clip_grad_norm_(edit_net.parameters(), 1.)
            editnet_optimizer.step()

            if i % print_every == 0:
                log_msg = 'Epoch: %d, Step: %d, Loss: %.4f' % (
                    epoch, i, np.mean(print_loss))
                print_loss = []
                print(log_msg)

            # Checkpoint
            if i % check_every == 0:
                edit_net.eval()

                val_loss, bleu_score, sari, sys_out = evaluator.evaluate(
                    eval_dataset, vocab, edit_net, args)
                log_msg = "epoch %d, step %d, Dev loss: %.4f, Bleu score: %.4f, Sari: %.4f \n" % (
                    epoch, i, val_loss, bleu_score, sari)
                print(log_msg)

                if val_loss < best_eval_loss:
                    best_eval_loss = val_loss
                Checkpoint(
                    model=edit_net,
                    opt=editnet_optimizer,
                    epoch=epoch,
                    step=i,
                ).save(args.store_dir)
                print("checked after %d steps" % i)

                edit_net.train()

    print(edit_net)
    return edit_net
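
A minimal sketch of the masked, length-normalized NLL computed above, using toy shapes and hypothetical PAD/UNK ids (the snippet's vocab object is not shown): per-token losses are zeroed at UNK targets, summed over time, divided by each sentence's non-PAD length, then averaged.

import torch
import torch.nn as nn

vocab_size, pad_id, unk_id = 10, 0, 1
criterion = nn.NLLLoss(ignore_index=pad_id, reduction='none')

log_probs = torch.log_softmax(torch.randn(2, 5, vocab_size), dim=-1)  # (batch, time, vocab)
targets = torch.tensor([[4, 3, 1, 0, 0],
                        [2, 2, 5, 6, 0]])                              # 0 = PAD, 1 = UNK

flat = criterion(log_probs.reshape(-1, vocab_size), targets.reshape(-1))
flat[targets.reshape(-1) == unk_id] = 0            # drop the loss at UNK positions
per_sentence = flat.view(targets.size()).sum(1)    # sum over time steps
lengths = targets.ne(pad_id).sum(1).float()        # non-PAD tokens per sentence
loss = (per_sentence / lengths).mean()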
Esempio n. 37
0
def main(_):
    checkpoint = Checkpoint(FLAGS.checkpoint_dir)
    utils.exists_or_mkdir(FLAGS.sample_dir)
    utils.exists_or_mkdir(FLAGS.log_dir)
    summaryWriter = tensorboardX.SummaryWriter(log_dir=FLAGS.log_dir)  # torch.utils.tensorboard.SummaryWriter(log_dir=FLAGS.log_dir)

    logger.info('[Params] lr:%f, size:%d, dataset:%s, av_gen:%d, n_disc:%d'%
                (FLAGS.learning_rate, FLAGS.output_size, FLAGS.dataset, int(FLAGS.use_averaged_gen), FLAGS.n_discriminator))

    #dataset
    z_shape = (FLAGS.z_dim,)
    image_size = (FLAGS.output_size, FLAGS.output_size)
    image_shape = (3,) + image_size

    ds = dataset.datasets.from_name(name=FLAGS.dataset, data_folder=FLAGS.data_folder,
                                    output_size=image_size)

    batch = batch_gen.BatchWithNoise(ds, batch_size=FLAGS.batch_size, z_shape=z_shape, num_workers=10)

    #initialize device
    device = utils.get_torch_device()

    #model
    nn_model = models.model_factory.create_model(FLAGS.model_name,
                                                 device=device,
                                                 image_shape=image_shape,
                                                 z_shape=z_shape,
                                                 use_av_gen=FLAGS.use_averaged_gen,
                                                 g_tanh=False)
    nn_model.register_checkpoint(checkpoint)

    loss = gan_loss.js_loss()
    #lambd = lambda_scheduler.Constant(0.1)
    lambd = lambda_scheduler.ThresholdAnnealing(1000., threshold=loss.lambda_switch_level, min_switch_step=FLAGS.lambda_switch_steps, verbose=True)
    checkpoint.register('lambda', lambd, True)

    trainer = Trainer(model=nn_model, batch=batch, loss=loss, lr=FLAGS.learning_rate,
                      reg='gp', lambd=lambd)
    trainer.sub_batches = FLAGS.batch_per_update
    trainer.register_checkpoint(checkpoint)

    it_start = checkpoint.load(FLAGS.checkpoint_it_to_load)

    trainer.update_lr()

    ##========================= LOAD CONTEXT ================================##
    context_path = os.path.join(FLAGS.checkpoint_dir, 'context.npz')
    
    sample_seed = None
    if os.path.exists(context_path):
        sample_seed = np.load(context_path)['z']
        if sample_seed.shape[0] != FLAGS.sample_size or sample_seed.shape[1] != FLAGS.z_dim:
            sample_seed = None
            logger.info('Invalid sample seed')
        else:
            logger.info('Sample seed loaded')
    
    if sample_seed is None:
        sample_seed = batch.sample_z(FLAGS.sample_size).data.numpy()
        np.savez(context_path, z = sample_seed)

    ##========================= TRAIN MODELS ================================##
    batches_per_epoch = 10000
    total_time = 0

    bLambdaSwitched = (it_start == 0)
    n_too_good_d = []

    number_of_iterations = FLAGS.epoch*batches_per_epoch
    for it in range(number_of_iterations):
        start_time = time.time()
        iter_counter = it + it_start

        # updates the discriminator
        #if iter_counter < 25 or iter_counter % 500 == 0:
        #    d_iter = 20
        #else:
        #    d_iter = 5
        if bLambdaSwitched:
            #if lambda was switched we want to keep discriminator optimal
            logger.info('[!] Warming up discriminator')
            d_iter = 25
        else:
            d_iter = FLAGS.n_discriminator
        errD, s, errG, b_too_good_D = trainer.update(d_iter, 1)

        summaryWriter.add_scalar('d_loss', errD, iter_counter)
        summaryWriter.add_scalar('slope', s, iter_counter)
        summaryWriter.add_scalar('g_loss', errG, iter_counter)
        summaryWriter.add_scalar('loss', errD + float(lambd) * s**2, iter_counter)
        summaryWriter.add_scalar('lambda', float(lambd), iter_counter)

        #updating lambda
        n_too_good_d.append(b_too_good_D)
        if len(n_too_good_d) > 20:
            del n_too_good_d[0]               
                
        bLambdaSwitched = lambd.update(errD)
        if not bLambdaSwitched and sum(n_too_good_d) > 10:
            bLambdaSwitched = lambd.switch()

        end_time = time.time()

        iter_time = end_time - start_time
        total_time += iter_time

        logger.info("[%2d/%2d] time: %4.4f, d_loss: %.8f, s: %.4f, g_loss: %.8f" % (iter_counter, it_start + number_of_iterations, iter_time, errD, s, errG))

        if np.mod(iter_counter, FLAGS.sample_step) == 0 and it > 0:
            n = int(np.sqrt(FLAGS.sample_size))

            img = trainer.sample(sample_seed)
            img = img.data.cpu()

            img_tb = utils.image_to_tensorboard(torchvision.utils.make_grid(img, n))
            summaryWriter.add_image('samples',img_tb, iter_counter)

            utils.save_images(img.data.cpu().numpy(), [n, n], './{}/train_{:02d}.png'.format(FLAGS.sample_dir, iter_counter))

        if np.mod(iter_counter, FLAGS.save_step) == 0 and it > 0:
            checkpoint.save(iter_counter)

    checkpoint.save(iter_counter)
Esempio n. 38
0
    tgt = TargetField()
    max_len = 150

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len * 3

    dev = torchtext.data.TabularDataset(path=opt.dev_path,
                                        format='tsv',
                                        fields=[('src', src), ('tgt', tgt)],
                                        filter_pred=len_filter)

    logging.info("loading checkpoint from {}".format(
        os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                     opt.load_checkpoint)))
    checkpoint_path = os.path.join(opt.expt_dir,
                                   Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    src.vocab = checkpoint.input_vocab
    tgt.vocab = checkpoint.output_vocab

    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()
    evaluator = Evaluator(loss=loss, batch_size=32)
    accuracy = evaluator.test(seq2seq, dev)
    print(accuracy)
Esempio n. 39
0
import utility
import data
import model
import loss
from option import args
from checkpoint import Checkpoint
from trainer import Trainer

utility.set_seed(args.seed)  # set the random seed so results are reproducible
checkpoint = Checkpoint(args)

if checkpoint.ok:
    loader = data.Data(args)
    model = model.Model(args, checkpoint)
    loss = loss.Loss(args, checkpoint) if not args.test_only else None
    t = Trainer(args, loader, model, loss, checkpoint)
    while not t.terminate():
        t.train()
        t.test()
    checkpoint.done()

Esempio n. 40
0
	def testCheckpointRestart(self):
		x = Checkpoint(TestCheckpoint.TEST_DIR)
		x.createCheckpointLog(TestCheckpoint.TEST_KEY)
		x.createCheckpointLog(TestCheckpoint.TEST_KEY2)

		x.writeCheckpoint(TestCheckpoint.TEST_KEY, "a", 1)
		x.writeCheckpoint(TestCheckpoint.TEST_KEY, "a", 2)
		x.writeCheckpoint(TestCheckpoint.TEST_KEY2, "b", 3)

		x.releaseCheckpointLog(TestCheckpoint.TEST_KEY)
		x.releaseCheckpointLog(TestCheckpoint.TEST_KEY2)
		
		y = Checkpoint(TestCheckpoint.TEST_DIR)
		self.assertEqual(y.getCheckpoints("b"), [3])
		self.assertEqual(y.getCheckpoints("a"), [1, 2])